# Text Classification Demo

## Text Classification

The goal of text classification is to assign documents to one of a fixed number of predefined categories. In this notebook, we will walk through an example of text classification on a well-known text classification data set using Spark machine learning algorithms. Specifically, we will classify the documents into two categories, which makes this a binary classification problem.

## Spark MLlib

MLlib is Spark’s machine learning (ML) library. Its goal is to make practical machine learning scalable and easy. It consists of common learning algorithms and utilities, including classification, regression, clustering, collaborative filtering and dimensionality reduction, as well as lower-level optimization primitives and higher-level pipeline APIs. In this notebook, we will explore how to use feature extraction transformations and classification algorithms for document classification, and how to combine them into a single pipeline or workflow. Specifically, we will use the spark.ml package of Spark MLlib, which provides a high-level API for constructing ML pipelines on top of DataFrames (a Spark abstraction for a distributed collection of data organized into named columns).

## Data Set

The 20 Newsgroups data set is a collection of approximately 20,000 newsgroup documents, partitioned (nearly) evenly across 20 different newsgroups, each corresponding to a different topic. The 20 Newsgroups collection has become a popular data set for experiments in text applications of machine learning techniques, such as text classification and text clustering. In this demo, we will use only a subset of the 20 Newsgroups data set consisting of 2000 articles: 100 articles from each of the 20 newsgroups. Some of the newsgroups are very closely related to each other (e.g. comp.sys.ibm.pc.hardware / comp.sys.mac.hardware), while others are highly unrelated (e.g. misc.forsale / soc.religion.christian). Here is a list of the 20 newsgroups, partitioned (more or less) according to subject matter:
  
* comp.graphics
* comp.os.ms-windows.misc
* comp.sys.ibm.pc.hardware
* comp.sys.mac.hardware
* comp.windows.x
* rec.autos
* rec.motorcycles
* rec.sport.baseball
* rec.sport.hockey
* sci.crypt
* sci.electronics
* sci.med
* sci.space
* misc.forsale
* talk.politics.misc
* talk.politics.guns
* talk.politics.mideast
* talk.religion.misc
* alt.atheism
* soc.religion.christian
  
Acknowledgement: Hettich, S. and Bay, S. D. (1999). The UCI KDD Archive [http://kdd.ics.uci.edu]. Irvine, CA: University of California, Department of Information and Computer Science.

## Objective

In this exercise we are going to train a model to classify documents from the 20 Newsgroups data set into two categories, according to whether or not the documents are computer related. We will then tune the model and evaluate it against a test data set containing documents that the model was not trained on.

## One other thing to note

In this notebook, investigation of the DataFrame objects is illustrated with both the DataFrame API and SQL, after first registering the DataFrame as a temporary table. Registering a DataFrame as a table allows you to run SQL queries over its data. Showing both DataFrame and SQL access is done strictly for illustrative purposes.

## Load the data set
#### Obtain a subset of the 20 Newsgroups data set
A tarball of the 2000 document subset of the 20 Newsgroups data can be found at https://kdd.ics.uci.edu/databases/20newsgroups/mini_newsgroups.tar.gz.


In [1]:
import sys.process._
"rm -f mini_newsgroups.tar.gz".!
"wget https://kdd.ics.uci.edu/databases/20newsgroups/mini_newsgroups.tar.gz".!

--2016-10-18 19:39:54--  https://kdd.ics.uci.edu/databases/20newsgroups/mini_newsgroups.tar.gz
Resolving kdd.ics.uci.edu (kdd.ics.uci.edu)... 128.195.1.95
Connecting to kdd.ics.uci.edu (kdd.ics.uci.edu)|128.195.1.95|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1860687 (1.8M) [application/x-gzip]
Saving to: 'mini_newsgroups.tar.gz'


0

#### Explode the tarball

The result is 20 directories, one for each of the 20 newsgroup topics. Each directory contains 100 documents, each stored as a separate file.

In [2]:
"rm -rf mini_newsgroups".!
"tar -zxf mini_newsgroups.tar.gz".!

0

#### Show the resulting directory structure

In [3]:
"ls -l mini_newsgroups".!

0

total 0
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 alt.atheism
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 comp.graphics
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 comp.os.ms-windows.misc
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 comp.sys.ibm.pc.hardware
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 comp.sys.mac.hardware
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 comp.windows.x
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 misc.forsale
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 rec.autos
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 rec.motorcycles
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 rec.sport.baseball
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5f39cf201a0 users 4096 Apr 20  1997 rec.sport.hockey
drwxr-xr-x 2 s794-746ee4c5b0e6ab-a5

## Spark Setup
#### Validate Spark context and create a Spark SQL context
A Spark context is already defined in the notebook as sc.

In [4]:
println("Spark version = " + sc.version)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
println("Spark SQL context: " + sqlContext)
import sqlContext.implicits._

Spark version = 1.6.0
Spark SQL context: org.apache.spark.sql.SQLContext@bca0929c


#### Import required machine learning libraries

In [5]:
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, StopWordsRemover, IDF, Tokenizer}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.mllib.linalg.Vector

## Read in the newsgroups documents

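wholeTextFiles reads a directory structure containing multiple small text files and returns each file as a (filepath, content) pair. Because the path below uses the glob pattern mini_newsgroups/*, a single call reads the files in all 20 topic directories into one RDD, so there is no need to read each topic separately and union the results.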

In [6]:
val path = "mini_newsgroups/*"
val newsgroupsRawData = sc.wholeTextFiles(path)
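
To make the shape of that result concrete, here is a plain-Scala, single-machine sketch of the (filepath, content) pairs that sc.wholeTextFiles("mini_newsgroups/*") returns: one pair per file in each topic directory. The helper name localWholeTextFiles is hypothetical, and decoding the files as ISO-8859-1 is an assumption of this sketch (it accepts any byte sequence); Spark itself reads the files in parallel across the cluster.

```scala
import java.io.File
import scala.io.Source

// Hypothetical helper, not part of Spark: mimics the (filepath, content) pairs
// that wholeTextFiles produces for a root directory of topic subdirectories.
def localWholeTextFiles(root: File): Seq[(String, String)] = {
  // listFiles returns null for a missing directory, so wrap it in an Option
  def children(dir: File): Seq[File] =
    Option(dir.listFiles).map(_.toSeq).getOrElse(Seq.empty)

  for {
    topicDir <- children(root).filter(_.isDirectory) // one directory per newsgroup
    file     <- children(topicDir).filter(_.isFile)  // one file per article
  } yield {
    // ISO-8859-1 is an assumption of this sketch: it decodes any byte sequence
    val source = Source.fromFile(file, "ISO-8859-1")
    try { (file.getPath, source.mkString) } finally { source.close() }
  }
}

// Usage, assuming the tarball was extracted into the working directory
val localPairs = localWholeTextFiles(new File("mini_newsgroups"))
println("Documents found locally: " + localPairs.size)
```

If the extracted directory is present, the count should match the RDD count shown next; the point of the sketch is only the pair structure, not performance.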

#### Show a count of the number of documents read in

In [7]:
println("The number of documents read in is " + newsgroupsRawData.count() + ".")

The number of documents read in is 2000.


#### Look at a sample (filepath, content) pair

In [10]:
newsgroupsRawData.takeSample(false, 1, 10L).foreach(println)

(file:/gpfs/global_fs01/sym_shared/YPProdSpark/user/s794-746ee4c5b0e6ab-a5f39cf201a0/notebook/work/mini_newsgroups/rec.motorcycles/105207,Path: cantaloupe.srv.cs.cmu.edu!rochester!udel!darwin.sura.net!howland.reston.ans.net!usc!news.service.uci.edu!cerritos.edu!tanner
From: tanner@cerritos.edu
Newsgroups: rec.motorcycles
Subject: Re: Posted Gif of BMW R100S
Message-ID: <1993Apr26.020339.8132@cerritos.edu>
Date: 26 Apr 93 02:03:39 PST
References: <1993Apr22.201652.17882@news.columbia.edu>
Organization: Cerritos College, Norwalk CA
Lines: 26

> 	If any would care to see any more close-ups or different angles, I can
> 	post others to a.b.p also. I would be happy to submit one to cerritos
> 	if someone wants to write me and tell me how...

I would prefer a picture with you in it.  Since most motorcycles don't post,
and are rather similar looking (i.e all R100S's are more alike than they are
different), it is the people that are ultimately more interesting.

From archive_policy.txt:
> If yo

#### Strip out the filepath from the (filepath, content) pair
Remember that wholeTextFiles returns (filepath, content) pairs.

In [11]:
val filepath = newsgroupsRawData.map{case(filepath,text) => (filepath)}

#### Look at some sample filepaths

In [12]:
filepath.takeSample(false, 5, 10L).foreach(println)

file:/gpfs/global_fs01/sym_shared/YPProdSpark/user/s794-746ee4c5b0e6ab-a5f39cf201a0/notebook/work/mini_newsgroups/sci.electronics/54244
file:/gpfs/global_fs01/sym_shared/YPProdSpark/user/s794-746ee4c5b0e6ab-a5f39cf201a0/notebook/work/mini_newsgroups/talk.politics.guns/55231
file:/gpfs/global_fs01/sym_shared/YPProdSpark/user/s794-746ee4c5b0e6ab-a5f39cf201a0/notebook/work/mini_newsgroups/rec.motorcycles/104974
file:/gpfs/global_fs01/sym_shared/YPProdSpark/user/s794-746ee4c5b0e6ab-a5f39cf201a0/notebook/work/mini_newsgroups/sci.crypt/15702
file:/gpfs/global_fs01/sym_shared/YPProdSpark/user/s794-746ee4c5b0e6ab-a5f39cf201a0/notebook/work/mini_newsgroups/talk.politics.misc/176982


#### Extract the document text from the (filepath, content) pair

In [13]:
val text = newsgroupsRawData.map{case(filepath,text) => text}

#### Validate that just the text was extracted from the (filepath, content) pair
Note that no filepath information is included.

In [14]:
text.takeSample(false, 1, 10L).foreach(println)

Path: cantaloupe.srv.cs.cmu.edu!rochester!udel!darwin.sura.net!howland.reston.ans.net!usc!news.service.uci.edu!cerritos.edu!tanner
From: tanner@cerritos.edu
Newsgroups: rec.motorcycles
Subject: Re: Posted Gif of BMW R100S
Message-ID: <1993Apr26.020339.8132@cerritos.edu>
Date: 26 Apr 93 02:03:39 PST
References: <1993Apr22.201652.17882@news.columbia.edu>
Organization: Cerritos College, Norwalk CA
Lines: 26

> 	If any would care to see any more close-ups or different angles, I can
> 	post others to a.b.p also. I would be happy to submit one to cerritos
> 	if someone wants to write me and tell me how...

I would prefer a picture with you in it.  Since most motorcycles don't post,
and are rather similar looking (i.e all R100S's are more alike than they are
different), it is the people that are ultimately more interesting.

From archive_policy.txt:
> If you already have a picture in some machine-readable format (GIF preferred),
> you can FTP it to Cerritos.edu account 'anonymous' password 'i

#### Extract the filename from the full filepath
For example, extract the filename '54200' from the full filepath of '/resources/data/mini_newsgroups/talk.politics.guns/54200'.

In [15]:
val id = filepath.map(filepath => (filepath.split("/").takeRight(1))(0))

#### Show that filenames have been extracted from the full filepath

In [16]:
id.take(5).foreach(println)

53633
54244
53150
54237
53490


#### Extract the topic from the filepath
Documents in the data set are stored in directories according to topic. The lowest level directory represents the topic classification.

In [17]:
val topic = filepath.map (filepath => (filepath.split("/").takeRight(2))(0))
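
The two extractions above can be checked on a single path without Spark. In this sketch, parsePath is a hypothetical helper; parts.last is equivalent to takeRight(1)(0), and parts(parts.length - 2) is equivalent to takeRight(2)(0).

```scala
// Hypothetical helper: returns (id, topic) for one file path.
def parsePath(filepath: String): (String, String) = {
  val parts = filepath.split("/")
  val id = parts.last                  // file name, e.g. "54200"
  val topic = parts(parts.length - 2)  // enclosing directory, e.g. "talk.politics.guns"
  (id, topic)
}

val (sampleId, sampleTopic) =
  parsePath("/resources/data/mini_newsgroups/talk.politics.guns/54200")
println("id = " + sampleId + ", topic = " + sampleTopic)
```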

#### Validate that topics were extracted from the filepath

In [18]:
topic.distinct().take(20).foreach(println)

sci.electronics
rec.motorcycles
comp.graphics
alt.atheism
comp.windows.x
rec.sport.baseball
talk.politics.mideast
rec.autos
talk.politics.guns
misc.forsale
sci.space
sci.crypt
comp.sys.ibm.pc.hardware
soc.religion.christian
sci.med
talk.politics.misc
talk.religion.misc
comp.sys.mac.hardware
rec.sport.hockey
comp.os.ms-windows.misc


## Put the data into a DataFrame
#### Define a case class and convert to a DataFrame

In [19]:
case class newsgroupsCaseClass(id: String, text: String, topic: String)

val newsgroups = newsgroupsRawData.map{case (filepath, text) => 
    val id = filepath.split("/").takeRight(1)(0)
    val topic = filepath.split("/").takeRight(2)(0)
    newsgroupsCaseClass(id, text, topic)}.toDF()
newsgroups.cache()

[id: string, text: string, topic: string]

#### Show the DataFrame schema and display 5 rows of the DataFrame

In [23]:
newsgroups.printSchema()
newsgroups.sample(false,0.005,10L).show(5)

root
 |-- id: string (nullable = true)
 |-- text: string (nullable = true)
 |-- topic: string (nullable = true)

+------+--------------------+------------------+
|    id|                text|             topic|
+------+--------------------+------------------+
| 54160|Path: cantaloupe....|       alt.atheism|
|101624|Newsgroups: rec.a...|         rec.autos|
|103667|Newsgroups: rec.a...|         rec.autos|
| 55278|Xref: cantaloupe....|talk.politics.guns|
| 53772|Newsgroups: sci.e...|   sci.electronics|
+------+--------------------+------------------+
only showing top 5 rows



#### Show document count by topic
##### using DataFrame API

In [24]:
newsgroups.groupBy("topic").count().show()

+--------------------+-----+
|               topic|count|
+--------------------+-----+
|    rec.sport.hockey|  100|
|     sci.electronics|  100|
|             sci.med|  100|
|           rec.autos|  100|
|comp.sys.mac.hard...|  100|
|      comp.windows.x|  100|
|  rec.sport.baseball|  100|
|comp.sys.ibm.pc.h...|  100|
|        misc.forsale|  100|
|     rec.motorcycles|  100|
|           sci.crypt|  100|
|  talk.politics.misc|  100|
|       comp.graphics|  100|
|         alt.atheism|  100|
|talk.politics.mid...|  100|
|soc.religion.chri...|  100|
|  talk.politics.guns|  100|
|comp.os.ms-window...|  100|
|           sci.space|  100|
|  talk.religion.misc|  100|
+--------------------+-----+



#### Show only documents that are related to computer topics
##### That is, documents whose topic name begins with "comp"
###### Using DataFrame API

In [35]:
newsgroups.filter(newsgroups("topic").like("comp%")).sample(false,0.01,10L).show(5)

+-----+--------------------+--------------------+
|   id|                text|               topic|
+-----+--------------------+--------------------+
| 9519|Newsgroups: comp....|comp.os.ms-window...|
| 9814|Xref: cantaloupe....|comp.os.ms-window...|
|39078|Xref: cantaloupe....|       comp.graphics|
|38839|Newsgroups: comp....|       comp.graphics|
|38244|Newsgroups: comp....|       comp.graphics|
+-----+--------------------+--------------------+



## Training the Model
This demo will use the Spark MLlib logistic regression algorithm to classify the documents. Logistic regression requires a numeric label of type double and cannot work directly with the text topic categories that we extracted from the input data set. As stated above, the goal of this exercise is to classify documents according to whether or not they are computer related. As we saw above, the computer-related documents reside in directories whose names begin with "comp". We will now add a numeric column called 'label' with a value of 0 for all non-computer-related documents (those whose topic name does not begin with "comp") and a value of 1 for all computer-related documents (those whose topic name begins with "comp").

In [36]:
val labelednewsgroups = newsgroups.withColumn("label", newsgroups("topic").like("comp%").cast("double"))
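
The labeling rule can be mirrored in plain Scala: the SQL pattern "comp%" matches topic names that begin with "comp", and casting the resulting boolean to double yields 1.0 or 0.0. labelFor is a hypothetical helper for illustration. Applied to the 20 topic names, it marks five newsgroups as computer related, so with 100 documents per newsgroup about 500 of the 2000 documents should carry label 1.0.

```scala
// Hypothetical helper mirroring like("comp%").cast("double")
def labelFor(topic: String): Double = if (topic.startsWith("comp")) 1.0 else 0.0

val allTopics = Seq(
  "alt.atheism", "comp.graphics", "comp.os.ms-windows.misc", "comp.sys.ibm.pc.hardware",
  "comp.sys.mac.hardware", "comp.windows.x", "misc.forsale", "rec.autos", "rec.motorcycles",
  "rec.sport.baseball", "rec.sport.hockey", "sci.crypt", "sci.electronics", "sci.med",
  "sci.space", "soc.religion.christian", "talk.politics.guns", "talk.politics.mideast",
  "talk.politics.misc", "talk.religion.misc")

// Topics whose documents will be labeled 1.0
val computerTopics = allTopics.filter(t => labelFor(t) == 1.0)
println("Topics labeled 1.0: " + computerTopics.mkString(", "))
println("Expected positive documents: " + computerTopics.size * 100 + " of " + allTopics.size * 100)
```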

#### Show the label column
###### Using the DataFrame API
Label is set to 0 for non-computer-related topics and 1 for computer-related topics.

In [44]:
labelednewsgroups.sample(false,0.003,10L).show(5)
labelednewsgroups.filter(labelednewsgroups("topic").like("comp%")).sample(false,0.007,10L).show(5)

+-----+--------------------+--------------------+-----+
|   id|                text|               topic|label|
+-----+--------------------+--------------------+-----+
|55278|Xref: cantaloupe....|  talk.politics.guns|  0.0|
|38839|Newsgroups: comp....|       comp.graphics|  1.0|
|52403|Newsgroups: comp....|comp.sys.mac.hard...|  1.0|
|51996|Path: cantaloupe....|comp.sys.mac.hard...|  1.0|
|83717|Xref: cantaloupe....|  talk.religion.misc|  0.0|
+-----+--------------------+--------------------+-----+
only showing top 5 rows

+-----+--------------------+--------------------+-----+
|   id|                text|               topic|label|
+-----+--------------------+--------------------+-----+
| 9519|Newsgroups: comp....|comp.os.ms-window...|  1.0|
| 9814|Xref: cantaloupe....|comp.os.ms-window...|  1.0|
|39078|Xref: cantaloupe....|       comp.graphics|  1.0|
|38839|Newsgroups: comp....|       comp.graphics|  1.0|
|38244|Newsgroups: comp....|       comp.graphics|  1.0|
+-----+----------------

## Split data set into separate training (90%) and test (10%) data sets
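#### Randomly split the labeled DataFrame of (id, text, topic, label) rows
randomSplit assigns each row to a split independently at random, so the split sizes only approximate the requested 90/10 weights, as the counts below show. Fixing the seed makes the split repeatable on the same data.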

In [45]:
val Array(training, test) = labelednewsgroups.randomSplit(Array(0.9, 0.1), seed = 12345)

#### Show a count of the resulting training and test data sets

In [46]:
val totalCount = labelednewsgroups.count()
println("Total Document Count = " + totalCount)
println("Training Count = " + training.count() + ", " + training.count() * 100 / totalCount.toDouble + "%")
println("Test Count = " + test.count() + ", " + test.count() * 100 / totalCount.toDouble + "%")

Total Document Count = 2000
Training Count = 1777, 88.85%
Test Count = 223, 11.15%
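
A plain-Scala sketch of per-row random assignment shows why the counts above are close to, but not exactly, 90/10: each row draws its own uniform random number, so the split sizes vary around the expected 1800/200. bernoulliSplit is a hypothetical helper that illustrates the idea only; it is not Spark's implementation, so its counts will not match the 1777/223 split shown above.

```scala
import scala.util.Random

// Hypothetical helper: a row goes to the first split when its uniform draw is
// below trainFraction, otherwise to the second split.
def bernoulliSplit[T](rows: Seq[T], trainFraction: Double, seed: Long): (Seq[T], Seq[T]) = {
  val rng = new Random(seed)
  rows.partition(_ => rng.nextDouble() < trainFraction)
}

val docIds = (1 to 2000).toVector
val (trainIds, testIds) = bernoulliSplit(docIds, 0.9, seed = 12345L)
println("train = " + trainIds.size + ", test = " + testIds.size)
```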


## Configure an ML Pipeline
In machine learning, it is common to run a sequence of algorithms to process and learn from data. Spark ML represents such a workflow as a Pipeline, which consists of a sequence of PipelineStages (Transformers and Estimators) to be run in a specific order. A Transformer converts one DataFrame into another, typically by appending a column, while an Estimator is fit on a DataFrame to produce a Transformer (a model). The pipeline we are using in this example consists of five stages: Tokenizer, StopWordsRemover, HashingTF, Inverse Document Frequency (IDF) and LogisticRegression.

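**Tokenizer** splits each raw text document into words and adds a new column containing those words to the DataFrame. Spark's Tokenizer converts the text to lowercase and splits it on whitespace, so punctuation stays attached to the words; for example, a header line beginning "Subject:" yields the token "subject:".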
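
As a plain-Scala sketch of what Tokenizer does to a single document (lowercase, then split on whitespace), the hypothetical tokenize helper below can be run on one header line. It splits on runs of whitespace and drops empty tokens, which is a choice of this sketch; Spark's handling of repeated whitespace may differ.

```scala
// Hypothetical helper: lowercase, split on runs of whitespace, drop empty tokens
def tokenize(text: String): List[String] =
  text.toLowerCase.split("\\s+").filter(_.nonEmpty).toList

// Punctuation stays attached: prints List(subject:, re:, posted, gif, of, bmw, r100s)
println(tokenize("Subject: Re: Posted Gif of BMW R100S"))
```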

**StopWordsRemover** takes a sequence of strings as input and drops all stop words from it. Stop words are words that should be excluded from the input, typically because they appear frequently and don’t carry much meaning. A default list of English stop words is provided, and you can optionally supply your own list. We will use the default list, with case-insensitive matching (setCaseSensitive(false)).
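
A minimal sketch of stop-word filtering with and without case sensitivity follows. The six-word sampleStopWords set and the removeStopWords helper are hypothetical and for illustration only; Spark's default English list is much longer.

```scala
// Illustrative stop-word set (lowercase); not Spark's default list
val sampleStopWords = Set("a", "of", "the", "is", "and", "to")

// Hypothetical helper: drop tokens found in the stop-word set
def removeStopWords(tokens: Seq[String], caseSensitive: Boolean = false): Seq[String] =
  if (caseSensitive) tokens.filterNot(sampleStopWords.contains)
  else tokens.filterNot(t => sampleStopWords.contains(t.toLowerCase))

println(removeStopWords(Seq("The", "gif", "of", "a", "BMW")))
println(removeStopWords(Seq("The", "gif", "of", "a", "BMW"), caseSensitive = true))
```

With case-insensitive matching, "The" is removed; with case-sensitive matching it survives because the set only contains "the".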

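**HashingTF** takes sequences of terms and converts them into fixed-length term-frequency feature vectors. It uses the hashing trick: each term is mapped to a vector index by a hash function, and the term's count is added at that index. Because distinct terms can hash to the same index, the vector length (numFeatures) trades memory against collisions.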
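
The hashing trick can be sketched in a few lines of plain Scala: map each term to an index in [0, numFeatures) and sum the counts per index, giving a sparse term-frequency vector (represented here as a Map from index to count). Using String.hashCode with a non-negative modulo is an assumption of this sketch; the hash function Spark uses depends on the Spark version. The helper names are hypothetical.

```scala
// Non-negative remainder, so negative hash codes still map into [0, mod)
def nonNegativeMod(x: Int, mod: Int): Int = {
  val raw = x % mod
  if (raw < 0) raw + mod else raw
}

// Hypothetical helper: sparse term-frequency vector as index -> count
def hashTermFrequencies(terms: Seq[String], numFeatures: Int): Map[Int, Double] =
  terms
    .groupBy(term => nonNegativeMod(term.hashCode, numFeatures))
    .map { case (index, termsAtIndex) => index -> termsAtIndex.size.toDouble }

val tf = hashTermFrequencies(Seq("posted", "gif", "bmw", "gif"), numFeatures = 1000)
println(tf)
// With a single bucket every term collides, so all counts land on index 0
println(hashTermFrequencies(Seq("posted", "gif", "bmw", "gif"), numFeatures = 1))
```

The single-bucket case is the extreme form of collision; it illustrates why numFeatures is worth tuning.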

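**Inverse Document Frequency (IDF)** is a numerical measure of how much information a term provides. A term that appears in many documents across the corpus carries little information about any particular document, so IDF down-weights terms that appear frequently in the corpus. IDF is an Estimator: it is fit on the training documents to compute the document frequency of each feature, and the fitted model then rescales the HashingTF vectors.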
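
The weighting can be sketched with the formula given in the Spark MLlib TF-IDF documentation, idf(t) = ln((m + 1) / (df(t) + 1)), where m is the number of documents and df(t) the number of documents containing t; terms found in fewer than minDocFreq documents get weight 0. Spark applies this per hashed feature index; the sketch uses terms for readability, and idfWeights and the toy documents are hypothetical.

```scala
// Hypothetical helper: IDF weight per term over a corpus of documents (as term sets)
def idfWeights(docs: Seq[Set[String]], minDocFreq: Int = 0): Map[String, Double] = {
  val m = docs.size
  // Document frequency: number of documents that contain each term
  val docFreq: Map[String, Int] =
    docs.flatten.groupBy(identity).map { case (term, occurrences) => term -> occurrences.size }
  docFreq.map { case (term, df) =>
    term -> (if (df >= minDocFreq) math.log((m + 1.0) / (df + 1.0)) else 0.0)
  }
}

val toyDocs = Seq(Set("gif", "bmw"), Set("gif", "windows"), Set("gif", "hockey"))
val idfByTerm = idfWeights(toyDocs)
// "gif" is in every document: ln(4 / 4) = 0; "bmw" is in one: ln(4 / 2) = ln 2
println(idfByTerm)

// TF-IDF multiplies each term frequency by the term's IDF weight
val tfidf = Map("gif" -> 2.0, "bmw" -> 1.0).map { case (t, freq) => t -> freq * idfByTerm(t) }
println(tfidf)
```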

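**LogisticRegression** is a method used to predict a binary response. In Spark 1.6, the version used here, the spark.ml implementation of logistic regression supports only binary classification, which is sufficient for this two-category problem. For each document it outputs a probability vector whose first element is the probability of label 0 and whose second is the probability of label 1, and it predicts 1.0 when the probability of label 1 exceeds the threshold (0.5 here).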
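
A sketch of how a fitted binary logistic regression model scores one example: p1 = 1 / (1 + exp(-(w · x + b))) is the probability of label 1, the probability vector is [1 - p1, p1], and the prediction is 1.0 when p1 exceeds the threshold. The helper scoreExample and the weights, intercept and features are made up for illustration; they are not taken from the fitted model.

```scala
// Logistic (sigmoid) function
def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

// Hypothetical helper: returns ([P(label 0), P(label 1)], prediction)
def scoreExample(weights: Array[Double], intercept: Double, features: Array[Double],
                 threshold: Double = 0.5): (Array[Double], Double) = {
  val margin = weights.zip(features).map { case (w, x) => w * x }.sum + intercept
  val p1 = sigmoid(margin)
  (Array(1.0 - p1, p1), if (p1 > threshold) 1.0 else 0.0)
}

// Margin = 2.0 * 1.0 + (-1.0) * 0.5 + 0.0 = 1.5
val (sampleProbability, samplePrediction) = scoreExample(Array(2.0, -1.0), 0.0, Array(1.0, 0.5))
println("probability = " + sampleProbability.mkString("[", ", ", "]") + ", prediction = " + samplePrediction)
```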

In [47]:
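// Split each document into lowercase, whitespace-delimited words
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
// Drop words on the default English stop-word list, ignoring case
val remover = new StopWordsRemover().setInputCol("words").setOutputCol("filtered").setCaseSensitive(false)
// Hash the remaining words into 1000-dimensional term-frequency vectors
val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol("filtered").setOutputCol("rawFeatures")
// Rescale term frequencies by inverse document frequency (no minimum document frequency)
val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features").setMinDocFreq(0)
// Logistic regression on the default "features" and "label" columns, with L2 regularization
val lr = new LogisticRegression().setRegParam(0.01).setThreshold(0.5)
// Chain the five stages; they run in this order when the pipeline is fit or applied
val pipeline = new Pipeline().setStages(Array(tokenizer, remover, hashingTF, idf, lr))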
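
Conceptually, Pipeline.fit runs the stages in order: a Transformer stage transforms the data, while an Estimator stage (IDF and LogisticRegression here) is first fit on the data produced by the earlier stages, and the resulting model then transforms it. The fitted PipelineModel replays the same Transformers on new data, such as the test set. The sketch below models this with hypothetical ToyStage, ToyTransformer and ToyEstimator types over an abstract data type D; these are not Spark's classes, and for simplicity the sketch transforms the data after every stage.

```scala
// Hypothetical minimal types; D is the type of the data flowing through the pipeline.
sealed trait ToyStage[D]
final case class ToyTransformer[D](name: String, transform: D => D) extends ToyStage[D]
final case class ToyEstimator[D](name: String, fit: D => ToyTransformer[D]) extends ToyStage[D]

// Fit stages in order; each estimator sees the output of all earlier stages.
def fitPipeline[D](stages: Seq[ToyStage[D]], data: D): Vector[ToyTransformer[D]] =
  stages.foldLeft((data, Vector.empty[ToyTransformer[D]])) { case ((current, fitted), stage) =>
    val transformer = stage match {
      case t @ ToyTransformer(_, _) => t
      case ToyEstimator(_, fit)     => fit(current)
    }
    (transformer.transform(current), fitted :+ transformer)
  }._2

// Apply a fitted pipeline to (possibly new) data.
def transformAll[D](fitted: Seq[ToyTransformer[D]], data: D): D =
  fitted.foldLeft(data)((current, t) => t.transform(current))

// Usage: double every value, then center on the mean learned from the training data.
val doubler = ToyTransformer[Seq[Double]]("double", _.map(_ * 2))
val center = ToyEstimator[Seq[Double]]("center", xs => {
  val mean = xs.sum / xs.size // learned once, at fit time
  ToyTransformer[Seq[Double]]("center", _.map(_ - mean))
})
val toyModel = fitPipeline[Seq[Double]](Seq(doubler, center), Seq(1.0, 2.0, 3.0))
println(transformAll(toyModel, Seq(1.0, 2.0, 3.0)))
println(transformAll(toyModel, Seq(10.0)))
```

The second call shows the key property: the mean learned during fitting is reused unchanged on new data, just as the fitted IDF weights are reused on the test documents.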

#### Show Logistic Regression Parameters

In [56]:
println("Logistic Regression Features Column = " + lr.getFeaturesCol)
println("Logistic Regression Label Column = " + lr.getLabelCol)
println("Logistic Regression Threshold = " + lr.getThreshold)

Logistic Regression Features Column = features
Logistic Regression Label Column = label
Logistic Regression Threshold = 0.5


#### Examine the parameters associated with each stage in the pipeline

In [57]:
println("Tokenizer:")
println(tokenizer.explainParams())
println("*************************")
println("Remover:")
println(remover.explainParams())
println("*************************")
println("HashingTF:")
println(hashingTF.explainParams())
println("*************************")
println("IDF:")
println(idf.explainParams())
println("*************************")
println("LogisticRegression:")
println(lr.explainParams())
println("*************************")
println("Pipeline:")
println(pipeline.explainParams())

Tokenizer:
inputCol: input column name (current: text)
outputCol: output column name (default: tok_fe57090347cf__output, current: words)
*************************
Remover:
caseSensitive: whether to do case-sensitive comparison during filtering (default: false, current: false)
inputCol: input column name (current: words)
outputCol: output column name (default: stopWords_4ce720b2751c__output, current: filtered)
stopWords: stop words (default: [Ljava.lang.String;@6dafc8db)
*************************
HashingTF:
inputCol: input column name (current: filtered)
numFeatures: number of features (> 0) (default: 262144, current: 1000)
outputCol: output column name (default: hashingTF_8562bb6678ac__output, current: rawFeatures)
*************************
IDF:
inputCol: input column name (current: rawFeatures)
minDocFreq: minimum of documents in which a term should appear for filtering (default: 0, current: 0)
outputCol: output column name (default: idf_9f6372dcf662__output, current: features)
******

#### Examine the list of default Stop Words that were applied in the pipeline
Stop Words are words which should be excluded from the input, typically because the words appear frequently and don’t carry as much meaning. You may also optionally provide your own list of Stop Words.

In [58]:
remover.getStopWords.foreach(println)

a
about
above
across
after
afterwards
again
against
all
almost
alone
along
already
also
although
always
am
among
amongst
amoungst
amount
an
and
another
any
anyhow
anyone
anything
anyway
anywhere
are
around
as
at
back
be
became
because
become
becomes
becoming
been
before
beforehand
behind
being
below
beside
besides
between
beyond
bill
both
bottom
but
by
call
can
cannot
cant
co
con
could
couldnt
cry
de
describe
detail
do
done
down
due
during
each
eg
eight
either
eleven
else
elsewhere
empty
enough
etc
even
ever
every
everyone
everything
everywhere
except
few
fifteen
fify
fill
find
fire
first
five
for
former
formerly
forty
found
four
from
front
full
further
get
give
go
had
has
hasnt
have
he
hence
her
here
hereafter
hereby
herein
hereupon
hers
herself
him
himself
his
how
however
hundred
i
ie
if
in
inc
indeed
interest
into
is
it
its
itself
keep
last
latter
latterly
least
less
ltd
made
many
may
me
meanwhile
might
mill
mine
more
moreover
most
mostly
move
much
must
my
myself
name
namely
neither

## Fit the pipeline to the training documents

In [59]:
val model = pipeline.fit(training)

### Make predictions on documents in the test data set
#### Keep in mind that the model has not seen the documents in the test data set.

#### Run the model against the test data set

In [61]:
val predictions = model.transform(test)

#### Show results using DataFrame API

In [63]:
predictions.select("id", "topic", "probability", "prediction", "label").sample(false,0.01,10L).show(5)
predictions.select("id", "topic", "probability", "prediction", "label").filter(predictions("topic").like("comp%")).sample(false,0.1,10L).show(5)

+-----+--------------------+--------------------+----------+-----+
|   id|               topic|         probability|prediction|label|
+-----+--------------------+--------------------+----------+-----+
|54446|  talk.politics.guns|[0.85907386640572...|       0.0|  0.0|
|53911|     sci.electronics|[0.99375677040263...|       0.0|  0.0|
|67516|      comp.windows.x|[0.78541867608477...|       0.0|  1.0|
|51539|comp.sys.mac.hard...|[0.99035383473721...|       0.0|  1.0|
+-----+--------------------+--------------------+----------+-----+

+-----+--------------------+--------------------+----------+-----+
|   id|               topic|         probability|prediction|label|
+-----+--------------------+--------------------+----------+-----+
| 9151|comp.os.ms-window...|[0.99677353549827...|       0.0|  1.0|
|38942|       comp.graphics|[0.41936760457692...|       1.0|  1.0|
|51613|comp.sys.mac.hard...|[0.07059632616872...|       1.0|  1.0|
|60199|comp.sys.ibm.pc.h...|[0.01268251597292...|       1.0| 

#### Show all the fields in order to see the results of each stage in the pipeline

In [66]:
predictions.sample(false,0.01,10L).show(5)

+-----+--------------------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
|   id|                text|               topic|label|               words|            filtered|         rawFeatures|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
|54446|Path: cantaloupe....|  talk.politics.guns|  0.0|[path:, cantaloup...|[path:, cantaloup...|(1000,[0,2,17,32,...|(1000,[0,2,17,32,...|[1.80761903155666...|[0.85907386640572...|       0.0|
|53911|Newsgroups: sci.e...|     sci.electronics|  0.0|[newsgroups:, sci...|[newsgroups:, sci...|(1000,[0,48,49,52...|(1000,[0,48,49,52...|[5.06999486675751...|[0.99375677040263...|       0.0|
|67516|Path: cantaloupe....|      c

## Create an evaluator for the binary classification
In this example we will use area under the ROC curve as the evaluation metric. Receiver operating characteristic (ROC) is a graphical plot that illustrates the performance of a binary classifier system as its discrimination threshold is varied. The curve is created by plotting the true positive rate against the false positive rate at various threshold settings. The ROC curve is thus the sensitivity as a function of fall-out. The area under the ROC curve is useful for comparing and selecting the best machine learning model for a given data set. A model with an area under the ROC curve score near 1 has very good performance. A model with a score near 0.5 is about as good as flipping a coin.

In [67]:
val evaluator = new BinaryClassificationEvaluator().setMetricName("areaUnderROC")
println("Area under the ROC curve = " + evaluator.evaluate(predictions))

Area under the ROC curve = 0.8699226305609282
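
BinaryClassificationEvaluator reads the label column and, by default, the rawPrediction column, and ranks the test documents by score. As a sketch of what the resulting number means, not of how Spark computes it, the area under the ROC curve equals the probability that a randomly chosen positive document is scored higher than a randomly chosen negative one, with ties counted as one half. The helper areaUnderROC and the scores below are hypothetical.

```scala
// Hypothetical helper: pairwise (O(P * N)) area under the ROC curve
def areaUnderROC(scoreAndLabel: Seq[(Double, Double)]): Double = {
  val positives = scoreAndLabel.filter(_._2 == 1.0).map(_._1)
  val negatives = scoreAndLabel.filter(_._2 == 0.0).map(_._1)
  // 1 if the positive outranks the negative, 0.5 for a tie, 0 otherwise
  val pairScores = for (p <- positives; n <- negatives)
    yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
  pairScores.sum / (positives.size.toDouble * negatives.size)
}

// (score for label 1, true label) pairs with made-up scores
val perfect = Seq((0.9, 1.0), (0.8, 1.0), (0.2, 0.0), (0.1, 0.0))
val mixed   = Seq((0.9, 1.0), (0.3, 1.0), (0.5, 0.0), (0.1, 0.0))
println("perfect ranking: " + areaUnderROC(perfect)) // 1.0
println("mixed ranking: " + areaUnderROC(mixed))     // 0.75
```

In the mixed case, three of the four positive/negative pairs are ranked correctly, giving 0.75; a classifier that scores at random would land near 0.5.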


## Tune Hyperparameters
#### Generate hyperparameter combinations by taking the cross product of some parameter values

Spark MLlib algorithms provide many hyperparameters for tuning models. These hyperparameters are distinct from the model parameters (such as the logistic regression weights) that MLlib learns during fitting. Hyperparameter tuning chooses the combination of hyperparameters whose model performs best on data held out from fitting; with cross-validation, that held-out data comes from folds of the training set, so the test set remains unseen until the final evaluation. Every combination of the specified hyperparameter values is tried in order to find the one that leads to the model with the best evaluation result.

#### Build a Parameter Grid specifying what parameters and values will be evaluated in order to determine the best combination

In [69]:
val paramGrid = new ParamGridBuilder().
addGrid(hashingTF.numFeatures, Array(1000, 10000, 100000)).
addGrid(idf.minDocFreq, Array(0,10, 100)).
build()
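
The grid is the cross product of the value lists passed to addGrid, which a plain-Scala for comprehension reproduces. Representing each combination as a Map keyed by parameter name is an assumption of this sketch; Spark's build() returns an Array[ParamMap].

```scala
val numFeaturesValues = Seq(1000, 10000, 100000)
val minDocFreqValues = Seq(0, 10, 100)

// Every numFeatures value paired with every minDocFreq value
val combinations = for {
  numFeatures <- numFeaturesValues
  minDocFreq  <- minDocFreqValues
} yield Map("hashingTF.numFeatures" -> numFeatures, "idf.minDocFreq" -> minDocFreq)

combinations.foreach(println)
println(combinations.size + " combinations")
```

With 3 × 3 = 9 combinations and the two folds configured in the cross validator, the pipeline is fit 18 times during cross-validation, plus once more when the best combination is refit on the full training set.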

## Create a cross validator to tune the pipeline with the generated parameter grid
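Spark MLlib supports cross-validation for hyperparameter tuning. CrossValidator splits the training data into k folds (set with setNumFolds; k = 2 here). For each parameter combination, it fits the pipeline k times, each time on k - 1 folds, and evaluates the fitted model on the remaining fold. The combination with the best average metric is selected, and the pipeline is then refit with it on the full training set to produce the final model.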

In [70]:
val cv = new CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(2)
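
The selection logic can be sketched in plain Scala. Each row index is assigned to one of k folds at random; for every parameter setting, a model is fit on the other folds and scored on the held-out fold, and the setting with the best average score wins. The helpers kFolds and selectBest and the toy scoring function are hypothetical, and Spark's fold assignment differs in detail.

```scala
import scala.util.Random

// Hypothetical helper: (training indices, validation indices) for each of k folds
def kFolds(n: Int, k: Int, seed: Long): Seq[(Seq[Int], Seq[Int])] = {
  val rng = new Random(seed)
  val foldOf = Vector.fill(n)(rng.nextInt(k))
  (0 until k).map { fold =>
    val training = (0 until n).filter(i => foldOf(i) != fold)
    val validation = (0 until n).filter(i => foldOf(i) == fold)
    (training, validation)
  }
}

// Hypothetical helper: average the score over folds and keep the best parameter
def selectBest[P](params: Seq[P], folds: Seq[(Seq[Int], Seq[Int])])
                 (fitAndScore: (P, Seq[Int], Seq[Int]) => Double): (P, Double) =
  params
    .map { p =>
      val scores = folds.map { case (training, validation) => fitAndScore(p, training, validation) }
      (p, scores.sum / scores.size)
    }
    .maxBy(_._2)

val folds = kFolds(n = 100, k = 2, seed = 42L)
// Toy scoring function that simply prefers larger parameter values
val (bestParam, bestScore) = selectBest(Seq(1, 2, 3), folds)((p, _, _) => p.toDouble)
println("best = " + bestParam + ", average score = " + bestScore)
```

In the real pipeline, fitAndScore would fit the pipeline with one ParamMap on the training indices and return the area under the ROC curve on the validation indices.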

## Cross-evaluate the ML Pipeline to find the best model
Use the area under the ROC curve evaluator and the hyperparameters specified in the parameter grid.

In [71]:
val cvModel = cv.fit(training)
println("Area under the ROC curve for best fitted model = " + evaluator.evaluate(cvModel.transform(test)))

Area under the ROC curve for best fitted model = 0.9804158607350093


#### Let's see what improvement we achieve by tuning the hyperparameters using cross-evaluation 

In [72]:
val baselineAuc = evaluator.evaluate(predictions)
val tunedAuc = evaluator.evaluate(cvModel.transform(test))
println("Area under the ROC curve for non-tuned model = " + baselineAuc)
println("Area under the ROC curve for fitted model = " + tunedAuc)
println("Improvement = " + "%.2f".format((tunedAuc - baselineAuc) * 100 / baselineAuc) + "%")

Area under the ROC curve for non-tuned model = 0.8699226305609284
Area under the ROC curve for fitted model = 0.9804158607350093
Improvement = 12.70%


### Make improved predictions on documents using the Cross-validated model
#### using the Test data set
Using DataFrame API

In [73]:
val cvPredictions = cvModel.transform(test)
cvPredictions.select("id", "topic", "probability", "prediction", "label").sample(false,0.01,0L).show(5)
cvPredictions.select("id", "topic", "probability", "prediction", "label").filter(cvPredictions("topic").like("comp%")).sample(false,0.1,0L).show(5)

+-----+--------------------+--------------------+----------+-----+
|   id|               topic|         probability|prediction|label|
+-----+--------------------+--------------------+----------+-----+
|54724|    rec.sport.hockey|[0.98798227442083...|       0.0|  0.0|
|61180|           sci.space|[0.99048970449598...|       0.0|  0.0|
|67320|      comp.windows.x|[0.32190183491300...|       1.0|  1.0|
|52300|comp.sys.mac.hard...|[0.03445720977205...|       1.0|  1.0|
|83827|  talk.religion.misc|[0.95191178509327...|       0.0|  0.0|
+-----+--------------------+--------------------+----------+-----+

+-----+--------------------+--------------------+----------+-----+
|   id|               topic|         probability|prediction|label|
+-----+--------------------+--------------------+----------+-----+
|38379|       comp.graphics|[0.90242609785147...|       0.0|  1.0|
|66413|      comp.windows.x|[0.18273807588790...|       1.0|  1.0|
|66978|      comp.windows.x|[0.06678739707220...|       1.0| 

## Conclusion
#### This analysis was intended to illustrate how to use the spark.ml machine learning package with a machine learning pipeline. Although a document classification use case was demonstrated, many of the principles shown in this notebook can be applied to other machine learning use cases. The specific algorithms to use will, of course, depend on the use case.

#### Also note that although this demo illustrated how to tune a model, no attempt was made to find the best possible model. The intent was simply to show the methodology.