<a href="https://colab.research.google.com/github/rklepov/hse-cs-ml-2018-2019/blob/master/08-spark/03-ml/toyml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip search spark | grep INSTALLED || pip install pyspark==2.4.0 findspark

  INSTALLED: 2.4.0


In [0]:
import importlib.util
import os

# Force the JVM that Spark will be launched with (overrides any existing value).
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64/jre/"

import findspark

# Locate the installed pyspark package without importing it, instead of
# hard-coding an interpreter-specific dist-packages path.
_pyspark_spec = importlib.util.find_spec("pyspark")
if _pyspark_spec is None or _pyspark_spec.origin is None:
    raise ImportError(
        "pyspark is not installed in this environment; run the install cell first"
    )
findspark.init(os.path.dirname(_pyspark_spec.origin))

import pyspark

In [0]:
from zipfile import ZipFile
from io import BytesIO
import urllib.request

import ssl

# SECURITY NOTE: hostname checking and certificate verification are switched
# off for every request made with this context, so the server's identity is
# not authenticated. FIXME: drop these two overrides once the data host can be
# reached with a verifiable certificate.
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE


def download(url, path=None):
    """Fetch a zip archive from ``url`` and unpack every member of it.

    The whole archive is read into memory before extraction.

    Args:
        url: Address of the zip archive; opened with
            ``urllib.request.urlopen`` using the module-level ``ctx``.
        path: Directory to extract into. ``None`` (the default) extracts into
            the current working directory.

    Returns:
        None.
    """
    # Close the response as soon as the payload has been read.
    with urllib.request.urlopen(url, context=ctx) as response:
        payload = BytesIO(response.read())

    # Close the archive handle once all members are written out.
    with ZipFile(payload) as archive:
        archive.extractall(path)

In [0]:
# Location of the zipped SMS Spam Collection in the UCI repository.
SMS_SPAM_ZIP_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
download(SMS_SPAM_ZIP_URL)

In [5]:
!head readme

SMS Spam Collection v.1
-------------------------

1. DESCRIPTION
--------------

The SMS Spam Collection v.1 (hereafter the corpus) is a set of SMS tagged messages that have been collected for SMS Spam research. It contains one set of SMS messages in English of 5,574 messages, tagged acording being ham (legitimate) or spam. 

1.1. Compilation
----------------


In [0]:
# Reuse the active Spark session if one exists, otherwise start a new one.
spark = (
    pyspark.sql.SparkSession
    .builder
    .getOrCreate()
)

In [0]:
# The raw file is tab-separated; no header is declared, so the columns come
# back with Spark's generated names (_c0, _c1).
sms = spark.read.csv('SMSSpamCollection', sep='\t')

In [0]:
# Give the two generated columns meaningful names.
src = (
    sms
    .withColumnRenamed('_c0', 'label')
    .withColumnRenamed('_c1', 'text')
)

In [9]:
# Class balance: number of messages per label.
(
    src
    .groupBy('label')
    .count()
    .show()
)

+-----+-----+
|label|count|
+-----+-----+
|  ham| 4827|
| spam|  747|
+-----+-----+



In [0]:
from pyspark.sql import functions

from pyspark.ml import feature
from pyspark.ml import pipeline
from pyspark.ml import classification
from pyspark.ml import evaluation

In [11]:
# Preview what the plain Tokenizer produces for each message.
(
    feature.Tokenizer(inputCol='text', outputCol='token')
    .transform(src)
    .show()
)

+-----+--------------------+--------------------+
|label|                text|               token|
+-----+--------------------+--------------------+
|  ham|Go until jurong p...|[go, until, juron...|
|  ham|Ok lar... Joking ...|[ok, lar..., joki...|
| spam|Free entry in 2 a...|[free, entry, in,...|
|  ham|U dun say so earl...|[u, dun, say, so,...|
|  ham|Nah I don't think...|[nah, i, don't, t...|
| spam|FreeMsg Hey there...|[freemsg, hey, th...|
|  ham|Even my brother i...|[even, my, brothe...|
|  ham|As per your reque...|[as, per, your, r...|
| spam|WINNER!! As a val...|[winner!!, as, a,...|
| spam|Had your mobile 1...|[had, your, mobil...|
|  ham|I'm gonna be home...|[i'm, gonna, be, ...|
| spam|SIX chances to wi...|[six, chances, to...|
| spam|URGENT! You have ...|[urgent!, you, ha...|
|  ham|I've been searchi...|[i've, been, sear...|
|  ham|I HAVE A DATE ON ...|[i, have, a, date...|
| spam|XXXMobileMovieClu...|[xxxmobilemoviecl...|
|  ham|Oh k...i'm watchi...|[oh, k...i'm, wat...|


In [0]:
# Split each message on runs of whitespace (the pattern acts as the delimiter
# under RegexTokenizer's default settings) and keep tokens of length >= 3.
tokenizer = feature.RegexTokenizer(
    inputCol='text',
    outputCol='tokens',
    pattern=r'\s+',
    minTokenLength=3
)

# Bag-of-words counts, bounded by document frequency on both sides.
vectorizer = feature.CountVectorizer(
    inputCol='tokens',
    outputCol='v',
    minDF=5,
    maxDF=900
)

# Encode the string label as a numeric column 'y'.
label_indexer = feature.StringIndexer(
    inputCol='label',
    outputCol='y'
)

# Classifier trained on the count vectors; seeded for repeatable runs.
forest = classification.RandomForestClassifier(
    featuresCol='v',
    labelCol='y',
    seed=123
)

# Full text -> prediction pipeline, stages applied in this order.
main = pipeline.Pipeline(
    stages=(tokenizer, vectorizer, label_indexer, forest)
)

In [13]:
# Seeded random split with weights 70:30 for train and test.
train, test = src.randomSplit(weights=(70., 30.), seed=123)

# Fit every pipeline stage on the training part only.
main_model = main.fit(train)

# Score the held-out part and keep just the label and the model outputs.
# cache() returns the same DataFrame, marked for reuse by the cells below.
results = (
    main_model
    .transform(test)
    .select('y', 'rawPrediction', 'probability', 'prediction')
    .cache()
)

results.show(5)

+---+--------------------+--------------------+----------+
|  y|       rawPrediction|         probability|prediction|
+---+--------------------+--------------------+----------+
|0.0|[17.4577708451969...|[0.87288854225984...|       0.0|
|0.0|[17.9915243908026...|[0.89957621954013...|       0.0|
|0.0|[17.9915243908026...|[0.89957621954013...|       0.0|
|0.0|[17.9915243908026...|[0.89957621954013...|       0.0|
|0.0|[17.9915243908026...|[0.89957621954013...|       0.0|
+---+--------------------+--------------------+----------+
only showing top 5 rows



In [14]:
# Rows with the smallest 'probability' values first.
results.sort('probability', ascending=True).show(5)

+---+--------------------+--------------------+----------+
|  y|       rawPrediction|         probability|prediction|
+---+--------------------+--------------------+----------+
|1.0|[6.02774777924674...|[0.30138738896233...|       1.0|
|1.0|[7.05938187799983...|[0.35296909389999...|       1.0|
|1.0|[7.06069250555023...|[0.35303462527751...|       1.0|
|1.0|[7.28622884676247...|[0.36431144233812...|       1.0|
|1.0|[7.54685095248554...|[0.37734254762427...|       1.0|
+---+--------------------+--------------------+----------+
only showing top 5 rows



In [15]:
# Rows with the largest 'probability' values first.
results.sort(functions.col('probability').desc()).show(5)

+---+--------------------+--------------------+----------+
|  y|       rawPrediction|         probability|prediction|
+---+--------------------+--------------------+----------+
|0.0|[18.0893416907218...|[0.90446708453609...|       0.0|
|0.0|[18.0893416907218...|[0.90446708453609...|       0.0|
|1.0|[18.0893416907218...|[0.90446708453609...|       0.0|
|1.0|[18.0893416907218...|[0.90446708453609...|       0.0|
|1.0|[18.0893416907218...|[0.90446708453609...|       0.0|
+---+--------------------+--------------------+----------+
only showing top 5 rows



In [16]:
# Area under the ROC curve on the held-out rows; the evaluator's default
# score column and metric are spelled out so the number below is unambiguous.
roc_evaluator = evaluation.BinaryClassificationEvaluator(
    labelCol='y',
    rawPredictionCol='rawPrediction',
    metricName='areaUnderROC'
)
roc_evaluator.evaluate(results)

0.9314047828132821