## SPAM CLASSIFIER

### Downloading the dataset

In [1]:
import os

DATASETS_DIR = 'datasets'
MODELS_DIR = 'models'
TAR_DIR = os.path.join(DATASETS_DIR, 'tar')

SPAM_URL = 'https://spamassassin.apache.org/old/publiccorpus/20050311_spam_2.tar.bz2'
EASY_HAM_URL = 'https://spamassassin.apache.org/old/publiccorpus/20030228_easy_ham_2.tar.bz2'
HARD_HAM_URL = 'https://spamassassin.apache.org/old/publiccorpus/20030228_hard_ham.tar.bz2'

In [2]:
from urllib.request import urlretrieve
import tarfile
import shutil

def download_dataset(url):
    """download and unzip data from a url into the specified path"""
    
    # create directory if it doesn't exist
    if not os.path.isdir(TAR_DIR):
        os.makedirs(TAR_DIR)
    
    filename = url.rsplit('/', 1)[-1]
    tarpath = os.path.join(TAR_DIR, filename)
    
    # download the tar file if it doesn't exist
    try:
        tarfile.open(tarpath)
    except:
        urlretrieve(url, tarpath)
    
    with tarfile.open(tarpath) as tar:
        dirname = os.path.join(DATASETS_DIR, tar.getnames()[0])
        if os.path.isdir(dirname):
            shutil.rmtree(dirname)
        tar.extractall(path=DATASETS_DIR)
        
        cmds_path = os.path.join(dirname, 'cmds')
        if os.path.isfile(cmds_path):
            os.remove(cmds_path)
    
    return dirname

In [3]:
spam_dir = download_dataset(SPAM_URL)
easy_ham_dir = download_dataset(EASY_HAM_URL)
hard_ham_dir = download_dataset(HARD_HAM_URL)

In [4]:
easy_ham_filenames = [name for name in sorted(os.listdir(easy_ham_dir)) if len(name) > 20]
hard_ham_filenames = [name for name in sorted(os.listdir(hard_ham_dir)) if len(name) > 20]
spam_filenames = [name for name in sorted(os.listdir(spam_dir)) if len(name) > 20]

In [5]:
len(easy_ham_filenames)

1400

In [6]:
len(hard_ham_filenames)

250

In [7]:
len(spam_filenames)

1396

In [8]:
spam_dir

'datasets/spam_2'

### Parsing the emails

In [9]:
import email
import email.policy

def load_email(is_spam, filename, spam_path='datasets'):
    directory = "spam_2" if is_spam else "easy_ham_2"
    with open(os.path.join(spam_path, directory, filename), "rb") as f:
        return email.parser.BytesParser(policy=email.policy.default).parse(f)

In [10]:
easy_ham_emails = [load_email(is_spam=False, filename=name) for name in easy_ham_filenames]
spam_emails = [load_email(is_spam=True, filename=name) for name in spam_filenames]

### Take a look at some examples

In [11]:
easy_ham_emails[0].get_content().strip()

'Date:        Tue, 20 Aug 2002 17:27:47 -0500\n    From:        Chris Garrigues <cwg-exmh@DeepEddy.Com>\n    Message-ID:  <1029882468.3116.TMDA@deepeddy.vircio.com>\n\n\n  | I\'m hoping that all people with no additional sequences will notice are\n  | purely cosmetic changes.\n\nWell, first, when exmh (the latest one with your changes) starts, I get...\n\ncan\'t read "flist(totalcount,unseen)": no such element in array\n    while executing\n"if {$flist(totalcount,$mhProfile(unseen-sequence)) > 0} {\n\tFlagInner spool iconspool labelup\n    } else {\n\tFlagInner down icondown labeldown\n    }"\n    (procedure "Flag_MsgSeen" line 3)\n    invoked from within\n"Flag_MsgSeen"\n    (procedure "MsgSeen" line 8)\n    invoked from within\n"MsgSeen $msgid"\n    (procedure "MsgShow" line 12)\n    invoked from within\n"MsgShow $msgid"\n    (procedure "MsgChange" line 17)\n    invoked from within\n"MsgChange 4862 show"\n    invoked from within\n"time [list MsgChange $msgid $show"\n    (procedure "M

In [12]:
spam_emails[10].get_content().strip()

'Yes we do purchase uncollected Judicial Judgements!!!            st10                           .           \n\nIf you, your company or an acquaintance have an uncollected Judicial Judgement then please call us and find out how we can help you receive the money that the court states you are rightfully due.\n\nWe have strong interest in acquiring uncollected Judicial Judgements in your City and Area.\n\nJ T C is the largest firm in the world specializing in the purchase and collection of Judicial Judgements.\n\nCurrently we are processing over 455 million dollars worth of judgements in the United States alone. We have associate offices in virtually every city in the US and in most foreign countries.\n\nYou have nothing to lose and everything to gain by calling. There is absolutely no cost to you.\n\nWe can be reached Toll free at 1-888-557-5744. in the US or if you are in Canada call 1-310-842-3521. You can call 24 hours per day.\n\nThank you for your time.\n\n\n\n\n\n\n\n+++++++++++++

### Exploring email structures

In [13]:
def get_email_structure(email):
    if isinstance(email, str):
        return email
    payload = email.get_payload()
    if isinstance(payload, list):
        return "multipart({})".format(", ".join([
            get_email_structure(sub_email)
            for sub_email in payload
        ]))
    else:
        return email.get_content_type()

In [14]:
from collections import Counter
def structures_counter(emails):
    structures = Counter()
    for email in emails:
        structure = get_email_structure(email)
        structures[structure] += 1
    return structures

In [15]:
structures_counter(easy_ham_emails).most_common()

[('text/plain', 1343),
 ('multipart(text/plain, application/pgp-signature)', 35),
 ('multipart(text/plain, text/html)', 12),
 ('text/html', 2),
 ('multipart(text/plain, application/x-patch)', 1),
 ('multipart(multipart(text/plain, multipart(text/plain), text/plain), application/pgp-signature)',
  1),
 ('multipart(text/plain, multipart(text/plain))', 1),
 ('multipart(multipart(text/plain, text/html), image/jpeg, image/gif, image/gif, image/gif, image/gif)',
  1),
 ('multipart(text/plain, application/ms-tnef)', 1),
 ('multipart(text/plain, text/plain, text/plain)', 1),
 ('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',
  1),
 ('multipart(text/plain, application/ms-tnef, text/plain)', 1)]

In [16]:
structures_counter(spam_emails).most_common()

[('text/plain', 597),
 ('text/html', 589),
 ('multipart(text/plain, text/html)', 114),
 ('multipart(text/html)', 29),
 ('multipart(text/plain)', 25),
 ('multipart(multipart(text/html))', 18),
 ('multipart(multipart(text/plain, text/html))', 5),
 ('multipart(text/plain, application/octet-stream, text/plain)', 3),
 ('multipart(text/html, text/plain)', 2),
 ('multipart(text/html, image/jpeg)', 2),
 ('multipart(multipart(text/plain), application/octet-stream)', 2),
 ('multipart(text/plain, application/octet-stream)', 2),
 ('multipart(text/plain, multipart(text/plain))', 1),
 ('multipart(multipart(text/plain, text/html), image/jpeg, image/jpeg, image/jpeg, image/jpeg, image/jpeg)',
  1),
 ('multipart(multipart(text/plain, text/html), image/jpeg, image/jpeg, image/jpeg, image/jpeg, image/gif)',
  1),
 ('text/plain charset=us-ascii', 1),
 ('multipart(multipart(text/html), image/gif)', 1),
 ('multipart(multipart(text/plain, text/html), application/octet-stream, application/octet-stream, applic

In [17]:
for header, value in spam_emails[0].items():
    print(header,":",value)

Return-Path : <ilug-admin@linux.ie>
Delivered-To : yyyy@localhost.netnoteinc.com
Received : from localhost (localhost [127.0.0.1])	by phobos.labs.netnoteinc.com (Postfix) with ESMTP id 9E1F5441DD	for <jm@localhost>; Tue,  6 Aug 2002 06:48:09 -0400 (EDT)
Received : from phobos [127.0.0.1]	by localhost with IMAP (fetchmail-5.9.0)	for jm@localhost (single-drop); Tue, 06 Aug 2002 11:48:09 +0100 (IST)
Received : from lugh.tuatha.org (root@lugh.tuatha.org [194.125.145.45]) by    dogma.slashnull.org (8.11.6/8.11.6) with ESMTP id g72LqWv13294 for    <jm-ilug@jmason.org>; Fri, 2 Aug 2002 22:52:32 +0100
Received : from lugh (root@localhost [127.0.0.1]) by lugh.tuatha.org    (8.9.3/8.9.3) with ESMTP id WAA31224; Fri, 2 Aug 2002 22:50:17 +0100
Received : from bettyjagessar.com (w142.z064000057.nyc-ny.dsl.cnc.net    [64.0.57.142]) by lugh.tuatha.org (8.9.3/8.9.3) with ESMTP id WAA31201 for    <ilug@linux.ie>; Fri, 2 Aug 2002 22:50:11 +0100
Received : from 64.0.57.142 [202.63.165.34] by bettyjagessa

In [18]:
spam_emails[0]["Subject"]

'[ILUG] STOP THE MLM INSANITY'

### Split the data into train and test set

In [20]:
import numpy as np

In [21]:
from sklearn.model_selection import train_test_split

X = np.array(easy_ham_emails + spam_emails)
y = np.array([0] * len(easy_ham_emails) + [1] * len(spam_emails))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

### Handling HTML files

In [22]:
import re
from html import unescape

def html_to_plain_text(html):
    text = re.sub('<head.*?>.*?</head>', '', html, flags=re.M | re.S | re.I)
    text = re.sub('<a\s.*?>', ' HYPERLINK ', text, flags=re.M | re.S | re.I)
    text = re.sub('<.*?>', '', text, flags=re.M | re.S)
    text = re.sub(r'(\s*\n)+', '\n', text, flags=re.M | re.S)
    return unescape(text)

In [23]:
html_spam_emails = [email for email in X_train[y_train==1]
                    if get_email_structure(email) == "text/html"]
sample_html_spam = html_spam_emails[7]
print(sample_html_spam.get_content().strip()[:1000], "...")

<HR>
<html>
<div bgcolor="#FFFFCC">

  <p align="center"><a
href="http://www.fabulousmail.com"><img border="0"
src="http://www.fabulousmail.com/Toners2goLogo.jpg"
width="349" height="96"></a></p>
<p align="center"><font size="6" face="Arial MT
Black"><i>Tremendous Savings</i>
on Toners,&nbsp;</font></p>
<p align="center"><font size="6" face="Arial MT
Black">
Inkjets, FAX, and Thermal Replenishables!!</font></p>
<p><a href="http://www.fabulousmail.com">Toners 2 Go
</a>is your secret
weapon to lowering your cost for <a
href="http://www.fabulousmail.com">High Quality,
Low-Cost</a> printer
supplies!&nbsp; We have been in the printer
replenishables business since 1992,
and pride ourselves on rapid response and outstanding
customer service.&nbsp;
What we sell are 100% compatible replacements for
Epson, Canon, Hewlett Packard,
Xerox, Okidata, Brother, and Lexmark; products that
meet and often exceed
original manufacturer's specifications.</p>
<p><i><font size="4">Check out these
prices!</font

In [24]:
print(html_to_plain_text(sample_html_spam.get_content())[:1000], "...")


   HYPERLINK
Tremendous Savings
on Toners, 
Inkjets, FAX, and Thermal Replenishables!!
 HYPERLINK Toners 2 Go
is your secret
weapon to lowering your cost for  HYPERLINK High Quality,
Low-Cost printer
supplies!  We have been in the printer
replenishables business since 1992,
and pride ourselves on rapid response and outstanding
customer service. 
What we sell are 100% compatible replacements for
Epson, Canon, Hewlett Packard,
Xerox, Okidata, Brother, and Lexmark; products that
meet and often exceed
original manufacturer's specifications.
Check out these
prices!
        Epson Stylus
Color inkjet cartridge
(SO20108):     Epson's Price:
$27.99    
Toners2Go price: $9.95!
         HP
LaserJet 4 Toner Cartridge
(92298A):           
HP's
Price:
$88.99           
Toners2Go
  price: $41.75!
 
Come visit us on the web to check out our hundreds
of similar bargains at  HYPERLINK Toners
2 Go!
  request to be removed by clicking  HYPERLINK HERE
derekw
http://xent.com/mailman/listinfo/fork
 ...


In [25]:
def email_to_text(email):
    html = None
    for part in email.walk():
        ctype = part.get_content_type()
        if not ctype in ("text/plain", "text/html"):
            continue
        try:
            content = part.get_content()
        except: # in case of encoding issues
            content = str(part.get_payload())
        if ctype == "text/plain":
            return content
        else:
            html = content
    if html:
        return html_to_plain_text(html)

In [26]:
print(email_to_text(sample_html_spam)[:100], "...")


   HYPERLINK
Tremendous Savings
on Toners, 
Inkjets, FAX, and Thermal Replenishables!!
 HYPERLINK T ...


### Install NLTK (http://www.nltk.org/)

In [27]:
pip install nltk

Collecting nltk
[?25l  Downloading https://files.pythonhosted.org/packages/92/75/ce35194d8e3022203cca0d2f896dbb88689f9b3fce8e9f9cff942913519d/nltk-3.5.zip (1.4MB)
[K     |████████████████████████████████| 1.4MB 6.7MB/s eta 0:00:01
[?25hCollecting click (from nltk)
[?25l  Downloading https://files.pythonhosted.org/packages/d2/3d/fa76db83bf75c4f8d338c2fd15c8d33fdd7ad23a9b5e57eb6c5de26b430e/click-7.1.2-py2.py3-none-any.whl (82kB)
[K     |████████████████████████████████| 92kB 18.5MB/s eta 0:00:01
[?25hCollecting joblib (from nltk)
[?25l  Downloading https://files.pythonhosted.org/packages/b8/a6/d1a816b89aa1e9e96bcb298eb1ee1854f21662ebc6d55ffa3d7b3b50122b/joblib-0.15.1-py3-none-any.whl (298kB)
[K     |████████████████████████████████| 307kB 8.3MB/s eta 0:00:01
[?25hCollecting regex (from nltk)
[?25l  Downloading https://files.pythonhosted.org/packages/60/7c/0d46b10a87b3087e8e303fac923beb19ec839d7c5ea34971a12fafb22b52/regex-2020.5.14-cp36-cp36m-manylinux2010_x86_64.whl (675kB)
[K

In [28]:
try:
    import nltk

    stemmer = nltk.PorterStemmer()
    for word in ("Computations", "Computation", "Computing", "Computed", "Compute", "Compulsive"):
        print(word, "=>", stemmer.stem(word))
except ImportError:
    print("Error: stemming requires the NLTK module.")
    stemmer = None

Computations => comput
Computation => comput
Computing => comput
Computed => comput
Compute => comput
Compulsive => compuls


In [30]:
try:
    import urlextract # may require an Internet connection to download root domain names
    
    url_extractor = urlextract.URLExtract()
    print(url_extractor.find_urls("Will it detect github.com and https://youtu.be/7Pq-S557XQU?t=3m32s"))
except ImportError:
    print("Error: replacing URLs requires the urlextract module.")
    url_extractor = None

Error: replacing URLs requires the urlextract module.


### Convert emails into word counters

In [31]:
from sklearn.base import BaseEstimator, TransformerMixin

class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, strip_headers=True, lower_case=True, remove_punctuation=True,
                 replace_urls=True, replace_numbers=True, stemming=True):
        self.strip_headers = strip_headers
        self.lower_case = lower_case
        self.remove_punctuation = remove_punctuation
        self.replace_urls = replace_urls
        self.replace_numbers = replace_numbers
        self.stemming = stemming
    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        X_transformed = []
        for email in X:
            text = email_to_text(email) or ""
            if self.lower_case:
                text = text.lower()
            if self.replace_urls and url_extractor is not None:
                urls = list(set(url_extractor.find_urls(text)))
                urls.sort(key=lambda url: len(url), reverse=True)
                for url in urls:
                    text = text.replace(url, " URL ")
            if self.replace_numbers:
                text = re.sub(r'\d+(?:\.\d*(?:[eE]\d+))?', 'NUMBER', text)
            if self.remove_punctuation:
                text = re.sub(r'\W+', ' ', text, flags=re.M)
            word_counts = Counter(text.split())
            if self.stemming and stemmer is not None:
                stemmed_word_counts = Counter()
                for word, count in word_counts.items():
                    stemmed_word = stemmer.stem(word)
                    stemmed_word_counts[stemmed_word] += count
                word_counts = stemmed_word_counts
            X_transformed.append(word_counts)
        return np.array(X_transformed)

In [32]:
X_few = X_train[:3]
X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)
X_few_wordcounts

array([Counter({'number': 8, 'to': 5, 'mail': 4, 'thi': 3, 'http': 3, 'www': 3, 'year': 2, 'annuiti': 2, 'or': 2, 'e': 2, 'insur': 2, 'com': 2, 'not': 2, 'profession': 2, 'insurancemail': 2, 'net': 2, 'legal': 2, 'holi': 1, 'cow': 1, 'guarante': 1, 'rate': 1, 'commiss': 1, 'surrend': 1, 'limit': 1, 'time': 1, 'onli': 1, 'call': 1, 'mailto': 1, 'safe': 1, 'us': 1, 'today': 1, 'pleas': 1, 'fill': 1, 'out': 1, 'the': 1, 'form': 1, 'below': 1, 'for': 1, 'more': 1, 'inform': 1, 'name': 1, 'phone': 1, 'citi': 1, 'state': 1, 'we': 1, 'don': 1, 't': 1, 'want': 1, 'anyon': 1, 'receiv': 1, 'our': 1, 'who': 1, 'doe': 1, 'wish': 1, 'is': 1, 'commun': 1, 'sent': 1, 'be': 1, 'remov': 1, 'from': 1, 'list': 1, 'do': 1, 'repli': 1, 'messag': 1, 'instead': 1, 'go': 1, 'here': 1, 'notic': 1, 'insiq': 1, 'htm': 1}),
       Counter({'to': 17, 'you': 16, 'a': 9, 'number': 9, 'we': 8, 'for': 7, 'be': 7, 'your': 7, 'thi': 6, 'stori': 5, 'are': 5, 'if': 5, 'us': 5, 'look': 4, 'tell': 4, 'i': 3, 'am': 3, 'tv': 

### Convert word counts to vectors

In [33]:
from scipy.sparse import csr_matrix

class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, vocabulary_size=1000):
        self.vocabulary_size = vocabulary_size
    def fit(self, X, y=None):
        total_count = Counter()
        for word_count in X:
            for word, count in word_count.items():
                total_count[word] += min(count, 10)
        most_common = total_count.most_common()[:self.vocabulary_size]
        self.most_common_ = most_common
        self.vocabulary_ = {word: index + 1 for index, (word, count) in enumerate(most_common)}
        return self
    def transform(self, X, y=None):
        rows = []
        cols = []
        data = []
        for row, word_count in enumerate(X):
            for word, count in word_count.items():
                rows.append(row)
                cols.append(self.vocabulary_.get(word, 0))
                data.append(count)
        return csr_matrix((data, (rows, cols)), shape=(len(X), self.vocabulary_size + 1))

In [34]:
vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)
X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)
X_few_vectors

<3x11 sparse matrix of type '<class 'numpy.longlong'>'
	with 29 stored elements in Compressed Sparse Row format>

In [35]:
X_few_vectors.toarray()

array([[ 81,   5,   8,   0,   0,   1,   3,   1,   1,   1,   0],
       [189,  17,   9,   9,  16,   3,   6,   8,   7,   7,   2],
       [290,   9,   4,   8,   7,  15,   2,   1,   1,   0,   6]],
      dtype=int64)

In [36]:
vocab_transformer.vocabulary_

{'to': 1,
 'number': 2,
 'a': 3,
 'you': 4,
 'the': 5,
 'thi': 6,
 'we': 7,
 'for': 8,
 'be': 9,
 'of': 10}

## Training models

In [37]:
from sklearn.pipeline import Pipeline

preprocess_pipeline = Pipeline([
    ("email_to_wordcount", EmailToWordCounterTransformer()),
    ("wordcount_to_vector", WordCounterToVectorTransformer()),
])

X_train_transformed = preprocess_pipeline.fit_transform(X_train)

In [47]:
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, precision_recall_curve, roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.neighbors import KNeighborsClassifier

### Logistic Regression

In [39]:
log_clf = LogisticRegression(solver="lbfgs", random_state=42)
score = cross_val_score(log_clf, X_train_transformed, y_train, verbose=3)
score.mean()

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


[CV]  ................................................................


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.3s remaining:    0.0s


[CV] ....................... , score=0.9865951742627346, total=   0.3s
[CV]  ................................................................


[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.8s remaining:    0.0s


[CV] ........................ , score=0.985254691689008, total=   0.4s
[CV]  ................................................................
[CV] ....................... , score=0.9865591397849462, total=   0.5s


[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    1.3s finished


0.986136335245563

### Support Vector Classifier

In [40]:
svc_clf = SVC(gamma='auto')
score = cross_val_score(svc_clf, X_train_transformed, y_train, verbose=3)
score.mean()

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


[CV]  ................................................................
[CV] ....................... , score=0.9463806970509383, total=   2.9s
[CV]  ................................................................


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    2.9s remaining:    0.0s


[CV] ....................... , score=0.9369973190348525, total=   3.2s
[CV]  ................................................................


[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    6.1s remaining:    0.0s


[CV] ....................... , score=0.9543010752688172, total=   3.0s


[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    9.1s finished


0.9458930304515359

In [41]:
linear_svc = LinearSVC()
score = cross_val_score(linear_svc, X_train_transformed, y_train, verbose=3)
score.mean()

[CV]  ................................................................


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.2s remaining:    0.0s


[CV] ........................ , score=0.985254691689008, total=   0.2s
[CV]  ................................................................
[CV] ....................... , score=0.9798927613941019, total=   0.2s
[CV]  ................................................................


[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.5s remaining:    0.0s


[CV] ....................... , score=0.9690860215053764, total=   0.3s


[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.7s finished


0.9780778248628287

### Decision Trees

In [48]:
tree_clf = DecisionTreeClassifier(max_depth=5)
score = cross_val_score(tree_clf, X_train_transformed, y_train, verbose=3)
score.mean()

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.1s remaining:    0.0s


[CV]  ................................................................
[CV] ........................ , score=0.935656836461126, total=   0.1s
[CV]  ................................................................
[CV] ....................... , score=0.9436997319034852, total=   0.1s
[CV]  ................................................................
[CV] ....................... , score=0.9489247311827957, total=   0.1s


[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.3s finished


0.9427604331824689

RandomForests

In [51]:
rnd_clf = RandomForestClassifier(n_estimators=500, max_leaf_nodes=16, n_jobs=-1)
score = cross_val_score(rnd_clf, X_train_transformed, y_train)
score.mean()



0.9637733383301142

In [42]:
def prediction(model):
    X_test_transformed = preprocess_pipeline.transform(X_test)
    model.fit(X_train_transformed, y_train)
    y_pred = model.predict(X_test_transformed)
    print(confusion_matrix(y_test,y_pred))
    print("Precision: {:.2f}%".format(100 * precision_score(y_test,y_pred)))
    print("Recall: {:.2f}%".format(100 * recall_score(y_test,y_pred)))
    print("F1: {:.2f}%".format(100 * f1_score(y_test,y_pred)))

In [43]:
prediction(log_clf)

[[269   4]
 [  3 284]]
Precision: 98.61%
Recall: 98.95%
F1: 98.78%




In [44]:
prediction(svc_clf)

[[260  13]
 [ 13 274]]
Precision: 95.47%
Recall: 95.47%
F1: 95.47%


In [45]:
prediction(linear_svc)

[[269   4]
 [  9 278]]
Precision: 98.58%
Recall: 96.86%
F1: 97.72%




In [49]:
prediction(tree_clf)

[[257  16]
 [ 12 275]]
Precision: 94.50%
Recall: 95.82%
F1: 95.16%


In [52]:
prediction(rnd_clf)

[[264   9]
 [  5 282]]
Precision: 96.91%
Recall: 98.26%
F1: 97.58%
