In [1]:
import os
import sys

os.environ["SPARK_HOME"] = "/usr/spark2.4.3"
os.environ["PYLIB"] = os.environ["SPARK_HOME"] + "/python/lib"
# In below two lines, use /usr/bin/python2.7 if you want to use Python 2
os.environ["PYSPARK_PYTHON"] = "/usr/local/anaconda/bin/python" 
os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/local/anaconda/bin/python"
sys.path.insert(0, os.environ["PYLIB"] +"/py4j-0.10.7-src.zip")
sys.path.insert(0, os.environ["PYLIB"] +"/pyspark.zip")

In [2]:
#from pyspark import SparkContext, SparkConf
#conf = SparkConf().setAppName("appName")
#sc = SparkContext(conf=conf)

In [3]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("PySpark ML Demo").enableHiveSupport().getOrCreate()

In [4]:
# What version of Spark?
print(spark.version)

2.4.3


# Load Datasets

In [5]:
# Load some airline flight data from a CSV file 'flights.csv'
# The file needs to be uploaded from your local machine to Jupyter Notebook in advance

# In the dataset, fields are separated by a comma and missing data are denoted by the string 'NA'.

# Data dictionary:
# mon — month (integer between 1 and 12)
# dom — day of month (integer between 1 and 31)
# dow — day of week (integer; 1 = Monday and 7 = Sunday)
# org — origin airport (IATA code)
# mile — distance (miles)
# carrier — carrier (IATA code)
# depart — departure time (decimal hour)
# duration — expected duration (minutes)
# delay — delay (minutes)

# Read the flights dataset
# inferSchema: Infer data types of columns automatically
# nullValue: Deal with missing data

flights_original = spark.read.csv('flights.csv',sep=',',header=True,inferSchema=True,nullValue='NA')
# Get number of records. The count() method gives the number of records.
print("The data contain %d records." % flights_original.count())

The data contain 50000 records.


In [6]:
# View the first five records. The show() method displays the first few records.
flights_original.show(5)

+---+---+---+-------+------+---+----+------+--------+-----+
|mon|dom|dow|carrier|flight|org|mile|depart|duration|delay|
+---+---+---+-------+------+---+----+------+--------+-----+
| 11| 20|  6|     US|    19|JFK|2153|  9.48|     351| null|
|  0| 22|  2|     UA|  1107|ORD| 316| 16.33|      82|   30|
|  2| 20|  4|     UA|   226|SFO| 337|  6.17|      82|   -8|
|  9| 13|  1|     AA|   419|ORD|1236| 10.33|     195|   -5|
|  4|  2|  5|     AA|   325|ORD| 258|  8.92|      65| null|
+---+---+---+-------+------+---+----+------+--------+-----+
only showing top 5 rows



In [7]:
# Check column data types. The dtypes attribute gives the column types
print(flights_original.dtypes)

[('mon', 'int'), ('dom', 'int'), ('dow', 'int'), ('carrier', 'string'), ('flight', 'int'), ('org', 'string'), ('mile', 'int'), ('depart', 'double'), ('duration', 'int'), ('delay', 'int')]


In [8]:
# The sms.csv dataset contains a selection of SMS messages which have been classified as
# either 'spam' (1) or 'ham' (0)
# Notes on CSV format:
# no header record and
# fields are separated by a semicolon (this is not the default separator of ',')
# Data dictionary:
# id — record identifier
# text — content of SMS message
# label — spam or ham (integer; 0 = ham and 1 = spam)

from pyspark.sql.types import StructType, StructField, IntegerType, StringType
# Specify column names and types
# Specify the data schema, giving columns names ("id", "text", and "label") and column types.
myschema = StructType([
StructField("id", IntegerType()),  
StructField("text", StringType()),
StructField("label", IntegerType())
])
# Load data from a delimited file
sms_original = spark.read.csv("sms.csv", sep=';', header=False, schema=myschema)

In [9]:
# Print schema of DataFrame
sms_original.printSchema()

root
 |-- id: integer (nullable = true)
 |-- text: string (nullable = true)
 |-- label: integer (nullable = true)



# Data Pre-processing

In [10]:
# We are going to develop a model which will predict whether or not a given flight will be delayed.
# Firstly we need to trim those data down by:
# - removing an uninformative column and
# - removing rows which do not have information about whether or not a flight was delayed.
# Remove the 'flight' column which is irrelevant for prediction
# The drop() method applies to columns only.
flights_drop_column = flights_original.drop('flight')

In [11]:
# Number of records with missing 'delay' values
# Use the filter() method to choose specific rows and
# the count() method to find the number of rows in the result.
flights_drop_column.filter('delay IS NULL').count()

2978

In [12]:
# Remove records with missing 'delay' values
flights_valid_delay = flights_drop_column.filter('delay IS NOT NULL')

In [13]:
# Remove records with missing values in any column and get the number of remaining rows
# The dropna() method will discard all records with any missing fields.
flights_none_missing = flights_valid_delay.dropna()
print("Number of records without any missing values: ", flights_none_missing.count())

Number of records without any missing values:  47022


In [15]:
# The next step of preparing the flight data has two parts:
# - convert the units of distance, replacing the mile column with a kmcolumn; and
# - create a Boolean column indicating whether or not a flight was delayed.
# Import the required function
from pyspark.sql.functions import round

In [16]:
# Convert 'mile' to 'km' and drop 'mile' column
# Use the withColumn() method to manipulate columns.
flights_adding_km = flights_none_missing.withColumn('km', round(flights_original.mile * 1.60934, 0)).drop('mile')

In [17]:
# Create 'label' column indicating whether flight delayed (1) or not (0)
# a flight to be "delayed" when it arrives 15 minutes or more after its scheduled time.
flights_adding_km = flights_adding_km.withColumn('label', (flights_adding_km.delay >= 15).cast('integer'))

In [18]:
# Check first five records
flights_adding_km.show(5)

+---+---+---+-------+---+------+--------+-----+------+-----+
|mon|dom|dow|carrier|org|depart|duration|delay|    km|label|
+---+---+---+-------+---+------+--------+-----+------+-----+
|  0| 22|  2|     UA|ORD| 16.33|      82|   30| 509.0|    1|
|  2| 20|  4|     UA|SFO|  6.17|      82|   -8| 542.0|    0|
|  9| 13|  1|     AA|ORD| 10.33|     195|   -5|1989.0|    0|
|  5|  2|  1|     UA|SFO|  7.98|     102|    2| 885.0|    0|
|  7|  2|  6|     AA|ORD| 10.83|     135|   54|1180.0|    1|
+---+---+---+-------+---+------+--------+-----+------+-----+
only showing top 5 rows



In [21]:
# In the flights data there are two columns, carrier and org, which hold categorical data.
# You need to transform those columns into indexed numerical values.
# Machine Learning model needs numbers not strings, so these transformations are vital!
# Import the appropriate class and create an indexer object to
# transform the carrier column from a string to an numeric index
from pyspark.ml.feature import StringIndexer

In [22]:
# Create an indexer; Prepare the indexer object on the flight data.
indexer = StringIndexer(inputCol='carrier', outputCol='carrier_idx')

In [23]:
# Indexer identifies categories in the data
indexer_model = indexer.fit(flights_adding_km)

In [24]:
# Indexer creates a new column with numeric index values
flights_indexed = indexer_model.transform(flights_adding_km)

In [25]:
# Repeat the process for the other categorical feature
flights_indexed = StringIndexer(inputCol='org', outputCol='org_idx').fit(flights_indexed).transform(flights_indexed)

In [28]:
print(flights_indexed.show(15))

+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+
|mon|dom|dow|carrier|org|depart|duration|delay|    km|label|carrier_idx|org_idx|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+
|  0| 22|  2|     UA|ORD| 16.33|      82|   30| 509.0|    1|        0.0|    0.0|
|  2| 20|  4|     UA|SFO|  6.17|      82|   -8| 542.0|    0|        0.0|    1.0|
|  9| 13|  1|     AA|ORD| 10.33|     195|   -5|1989.0|    0|        1.0|    0.0|
|  5|  2|  1|     UA|SFO|  7.98|     102|    2| 885.0|    0|        0.0|    1.0|
|  7|  2|  6|     AA|ORD| 10.83|     135|   54|1180.0|    1|        1.0|    0.0|
|  1| 16|  6|     UA|ORD|   8.0|     232|   -7|2317.0|    0|        0.0|    0.0|
|  1| 22|  5|     UA|SJC|  7.98|     250|  -13|2943.0|    0|        0.0|    5.0|
| 11|  8|  1|     OO|SFO|  7.77|      60|   88| 254.0|    1|        2.0|    1.0|
|  4| 26|  1|     AA|SFO| 13.25|     210|  -10|2356.0|    0|        1.0|    1.0|
|  4| 25|  0|     AA|ORD| 13

In [29]:
# The final stage of data preparation is to consolidate all of the predictor columns into a single column.
# An updated version of the flights data, which takes into account all of the changes,
# has the following predictor columns:
# - mon, dom and dow
# - carrier_idx (indexed value from carrier)
# - org_idx (indexed value from org)
# - km
# - depart
# - duration

# Import the class which will assemble the predictors.
from pyspark.ml.feature import VectorAssembler

In [32]:
# Create an assembler object that can merge the predictors columns into a single column.
assembler = VectorAssembler(inputCols=[
'mon', 'dom', 'dow', 'carrier_idx', 'org_idx', 'km', 'depart', 'duration'
], outputCol='features')

In [33]:
# Consolidate predictor columns
flights_assembled = assembler.transform(flights_indexed)

In [34]:
# Check the resulting column
flights_assembled.select('features', 'delay').show(5, truncate=False)

+-----------------------------------------+-----+
|features                                 |delay|
+-----------------------------------------+-----+
|[0.0,22.0,2.0,0.0,0.0,509.0,16.33,82.0]  |30   |
|[2.0,20.0,4.0,0.0,1.0,542.0,6.17,82.0]   |-8   |
|[9.0,13.0,1.0,1.0,0.0,1989.0,10.33,195.0]|-5   |
|[5.0,2.0,1.0,0.0,1.0,885.0,7.98,102.0]   |2    |
|[7.0,2.0,6.0,1.0,0.0,1180.0,10.83,135.0] |54   |
+-----------------------------------------+-----+
only showing top 5 rows



In [35]:
flights_assembled.show(5)

+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
|mon|dom|dow|carrier|org|depart|duration|delay|    km|label|carrier_idx|org_idx|            features|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
|  0| 22|  2|     UA|ORD| 16.33|      82|   30| 509.0|    1|        0.0|    0.0|[0.0,22.0,2.0,0.0...|
|  2| 20|  4|     UA|SFO|  6.17|      82|   -8| 542.0|    0|        0.0|    1.0|[2.0,20.0,4.0,0.0...|
|  9| 13|  1|     AA|ORD| 10.33|     195|   -5|1989.0|    0|        1.0|    0.0|[9.0,13.0,1.0,1.0...|
|  5|  2|  1|     UA|SFO|  7.98|     102|    2| 885.0|    0|        0.0|    1.0|[5.0,2.0,1.0,0.0,...|
|  7|  2|  6|     AA|ORD| 10.83|     135|   54|1180.0|    1|        1.0|    0.0|[7.0,2.0,6.0,1.0,...|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
only showing top 5 rows



# Classification using Decision Tree

In [36]:
# We will split the data into two components:
# - training data (used to train the model) and
# - testing data (used to test the model).
# Split into training and testing sets in a 80:20 ratio
flights_train, flights_test = flights_assembled.randomSplit([0.8, 0.2], seed=17)
print(flights_train.show(5))

+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
|mon|dom|dow|carrier|org|depart|duration|delay|    km|label|carrier_idx|org_idx|            features|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
|  0|  1|  2|     AA|JFK|  6.58|     230|   50|2570.0|    1|        1.0|    2.0|[0.0,1.0,2.0,1.0,...|
|  0|  1|  2|     AA|JFK|   7.0|     385|  -16|4162.0|    0|        1.0|    2.0|[0.0,1.0,2.0,1.0,...|
|  0|  1|  2|     AA|JFK|  12.0|     370|   11|3983.0|    0|        1.0|    2.0|[0.0,1.0,2.0,1.0,...|
|  0|  1|  2|     AA|JFK|  17.0|     379|  -10|3983.0|    0|        1.0|    2.0|[0.0,1.0,2.0,1.0,...|
|  0|  1|  2|     AA|LGA|  8.25|     250|   27|2235.0|    1|        1.0|    3.0|[0.0,1.0,2.0,1.0,...|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
only showing top 5 rows

None


In [37]:
# Check that training set has around 80% of records
training_ratio = flights_train.count() / flights_assembled.count()
print(training_ratio)

0.7980732423121092


In [39]:
# Import the Decision Tree Classifier class
from pyspark.ml.classification import DecisionTreeClassifier

# Create a classifier object and fit to the training data
tree = DecisionTreeClassifier()
tree_model = tree.fit(flights_train)

# Create predictions for the testing data and take a look at the predictions
prediction = tree_model.transform(flights_test)
prediction.select('label', 'prediction', 'probability').show(5, False)

+-----+----------+---------------------------------------+
|label|prediction|probability                            |
+-----+----------+---------------------------------------+
|1    |1.0       |[0.4931950745301361,0.5068049254698639]|
|1    |1.0       |[0.35528564453125,0.64471435546875]    |
|1    |1.0       |[0.35528564453125,0.64471435546875]    |
|1    |1.0       |[0.35528564453125,0.64471435546875]    |
|1    |1.0       |[0.35528564453125,0.64471435546875]    |
+-----+----------+---------------------------------------+
only showing top 5 rows



In [40]:
# Evaluate the decision tree using confusion matrix
# A confusion matrix gives a useful breakdown of predictions versus known values.
# It has four cells which represent the counts of:
# - True Negatives (TN): model predicts negative outcome & known outcome is negative
# - True Positives (TP): model predicts positive outcome & known outcome is positive
# - False Negatives (FN): model predicts negative outcome but known outcome is positive
# - False Positives (FP): model predicts positive outcome but known outcome is negative.
# Create a confusion matrix by counting the combinations of label and prediction. Display the result.
prediction.groupBy('label', 'prediction').count().show()
# Count # of True Negatives, True Positives, False Negatives and False Positives in confusion matrix
# Use the predicatea:
# - prediction = 0 AND label = prediction (TF)
# - prediction = 1 AND label = prediction (TP)
# - prediction = 0 AND label != prediction (FN)
# - prediction = 1 AND label != prediction (FP)
TN = prediction.filter('prediction = 0 AND label = prediction').count()
TP = prediction.filter('prediction = 1 AND label = prediction').count()
FN = prediction.filter('prediction = 0 AND label != prediction').count()
FP = prediction.filter('prediction = 1 AND label != prediction').count()
# Accuracy measures the proportion of correct predictions
# The accuracy is the ratio of correct predictions (TP and TN) to all predictions (TP, TN, FP and FN)
accuracy = (TN + TP) / (TN + TP + FN + FP)
print(accuracy)

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    1|       0.0| 1201|
|    0|       0.0| 2411|
|    1|       1.0| 3623|
|    0|       1.0| 2260|
+-----+----------+-----+

0.635492364402317


# Classification using Logistic Regression

In [42]:
### Use logistic regression to predict whether a flight is likely to be delayed by
# at least 15 minutes (label 1) or not (label 0).
# Import the logistic regression class
from pyspark.ml.classification import LogisticRegression
# Create a classifier object and train on training data
logistic = LogisticRegression().fit(flights_train)

print(flights_train.show(5))

+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
|mon|dom|dow|carrier|org|depart|duration|delay|    km|label|carrier_idx|org_idx|            features|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
|  0|  1|  2|     AA|JFK|  6.58|     230|   50|2570.0|    1|        1.0|    2.0|[0.0,1.0,2.0,1.0,...|
|  0|  1|  2|     AA|JFK|   7.0|     385|  -16|4162.0|    0|        1.0|    2.0|[0.0,1.0,2.0,1.0,...|
|  0|  1|  2|     AA|JFK|  12.0|     370|   11|3983.0|    0|        1.0|    2.0|[0.0,1.0,2.0,1.0,...|
|  0|  1|  2|     AA|JFK|  17.0|     379|  -10|3983.0|    0|        1.0|    2.0|[0.0,1.0,2.0,1.0,...|
|  0|  1|  2|     AA|LGA|  8.25|     250|   27|2235.0|    1|        1.0|    3.0|[0.0,1.0,2.0,1.0,...|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+--------------------+
only showing top 5 rows

None


In [43]:
# Create predictions for the testing data and show confusion matrix
prediction = logistic.transform(flights_test)
prediction.groupBy('label', 'prediction').count().show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    1|       0.0| 1652|
|    0|       0.0| 2645|
|    1|       1.0| 3172|
|    0|       1.0| 2026|
+-----+----------+-----+



In [44]:
## Evaluate the logistic regression model.
# Accuracy is generally not a very reliable metric because it can be biased by the most common target class
# There are two other useful metrics:
# - Precision is the proportion of positive predictions which are correct, ie. TP/(TP+FP)
# - Recall is the proportion of positives outcomes which are correctly predicted, ie. TP/(TP+FN)
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
# Calculate precision and recall
precision = TP / (TP + FP)
recall = TP / (TP + FN)
print('precision = {:.2f}\nrecall = {:.2f}'.format(precision, recall))

precision = 0.62
recall = 0.75


In [46]:
# Find weighted precision.
# The weighted precision indicates what proportion of predictions (positive and negative) are correct.
# Create a multi-class evaluator and evaluate weighted precision.
# The metric name is "weightedPrecision".
multi_evaluator = MulticlassClassificationEvaluator()
weighted_precision = multi_evaluator.evaluate(prediction, {multi_evaluator.metricName: "weightedPrecision"})
# Find AUC
# Create a binary evaluator and evaluate AUC using the "areaUnderROC" metric.
binary_evaluator = BinaryClassificationEvaluator()
auc = binary_evaluator.evaluate(prediction, {binary_evaluator.metricName: "areaUnderROC"})
print(auc)

0.6501216177018289


In [47]:
## Another example of classification using logistic regression for the sms dataset
# Firstly to repare the SMS messages as follows:
# - remove punctuation and numbers
# - tokenize (split into individual words)
# - remove stop words
# - apply the hashing trick
# - convert to TF-IDF representation.
# Import the necessary functions
from pyspark.sql.functions import regexp_replace
from pyspark.ml.feature import Tokenizer

# Use regular expressions (or REGEX) to remove the punctuation symbols.
# Replace all punctuation characters from the text column with a space.
# Do the same for all numbers in the text column.
wrangled = sms_original.withColumn('text', regexp_replace(sms_original.text, '[_():;,.!?\\-]', ' '))
wrangled = wrangled.withColumn('text', regexp_replace(wrangled.text, '[0-9]', ' '))
# Merge multiple spaces
sms_cleaned = wrangled.withColumn('text', regexp_replace(wrangled.text, ' +', ' '))
# Split the text into words
# Split the 'text' column into tokens. Name the output column 'words'
sms_tokenized = Tokenizer(inputCol='text', outputCol='words').transform(sms_cleaned)
sms_tokenized.show(4, truncate=False)

+---+----------------------------------+-----+------------------------------------------+
|id |text                              |label|words                                     |
+---+----------------------------------+-----+------------------------------------------+
|1  |Sorry I'll call later in meeting  |0    |[sorry, i'll, call, later, in, meeting]   |
|2  |Dont worry I guess he's busy      |0    |[dont, worry, i, guess, he's, busy]       |
|3  |Call FREEPHONE now                |1    |[call, freephone, now]                    |
|4  |Win a cash prize or a prize worth |1    |[win, a, cash, prize, or, a, prize, worth]|
+---+----------------------------------+-----+------------------------------------------+
only showing top 4 rows



In [39]:
# Evaluate the decision tree using confusion matrix
# A confusion matrix gives a useful breakdown of predictions versus known values.
# It has four cells which represent the counts of:
# - True Negatives (TN): model predicts negative outcome & known outcome is negative
# - True Positives (TP): model predicts positive outcome & known outcome is positive
# - False Negatives (FN): model predicts negative outcome but known outcome is positive
# - False Positives (FP): model predicts positive outcome but known outcome is negative.
# Create a confusion matrix by counting the combinations of label and prediction. Display the result.
prediction.groupBy('label', 'prediction').count().show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    1|       0.0| 1230|
|    0|       0.0| 2465|
|    1|       1.0| 3594|
|    0|       1.0| 2206|
+-----+----------+-----+



In [41]:
# Count # of True Negatives, True Positives, False Negatives and False Positives in confusion matrix
# Use the predicatea:
# - prediction = 0 AND label = prediction (TF)
# - prediction = 1 AND label = prediction (TP)
# - prediction = 0 AND label != prediction (FN)
# - prediction = 1 AND label != prediction (FP)
TN = prediction.filter('prediction = 0 AND label = prediction').count()
TP = prediction.filter('prediction = 1 AND label = prediction').count()
FN = prediction.filter('prediction = 0 AND label != prediction').count()
FP = prediction.filter('prediction = 1 AND label != prediction').count()

# Accuracy measures the proportion of correct predictions
# The accuracy is the ratio of correct predictions (TP and TN) to all predictions (TP, TN, FP and FN)
accuracy = (TN + TP) / (TN + TP + FN + FP)
print(accuracy)

0.6381253291205898


In [48]:
from pyspark.ml.feature import StopWordsRemover, HashingTF, IDF
# Remove stop words - to eliminate so commonly used words that carry very little useful info.
# StopWordsRemover class contains a list of stop words which can be customized if necessary.
sms_without_stop = StopWordsRemover(inputCol='words', outputCol='terms').transform(sms_tokenized)
# Apply the hashing trick
# The hashing trick provides a fast and space-efficient way to
# map a very large (possibly infinite) set of items (in this case, all words contained in the SMS messages)
# onto a smaller, finite number of values.
sms_hashed = HashingTF(inputCol='terms', outputCol='hash', numFeatures=1024).transform(sms_without_stop)

In [50]:
# Convert hashed symbols to TF-IDF representation
# The TF-IDF matrix reflects how important a word is to each document.
# It takes into account both the frequency of the word within each document but also
# the frequency of the word across all of the documents in the collection.
# ie. Weight the number of counts for a word in a particular document against
# how frequently that word occurs across all documents
sms_tfidf = IDF(inputCol='hash', outputCol='features').fit(sms_hashed).transform(sms_hashed)
sms_tfidf.select('terms', 'features').show(4, truncate=False)

+--------------------------------+----------------------------------------------------------------------------------------------------+
|terms                           |features                                                                                            |
+--------------------------------+----------------------------------------------------------------------------------------------------+
|[sorry, call, later, meeting]   |(1024,[138,344,378,1006],[2.2391682769656747,2.892706319430574,3.684405173719015,4.244020961654438])|
|[dont, worry, guess, busy]      |(1024,[53,233,329,858],[4.618714411095849,3.557143394108088,4.618714411095849,4.937168142214383])   |
|[call, freephone]               |(1024,[138,396],[2.2391682769656747,3.3843005812686773])                                            |
|[win, cash, prize, prize, worth]|(1024,[31,69,387,428],[3.7897656893768414,7.284881949239966,4.4671645129686475,3.898659777615979])  |
+--------------------------------+--------------

In [51]:
# Split the tf_idf data into training and testing sets in a 4:1 ratio
sms_train, sms_test = sms_tfidf.randomSplit([0.8, 0.2], seed=13)
# Fit a Logistic Regression model to the training data
logistic = LogisticRegression(regParam=0.2).fit(sms_train)
# Make predictions on the testing data
prediction = logistic.transform(sms_test)
# Create a confusion matrix, comparing predictions to known labels
prediction.groupBy('label', 'prediction').count().show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    1|       0.0|   47|
|    0|       0.0|  987|
|    1|       1.0|  124|
|    0|       1.0|    3|
+-----+----------+-----+

