<a href="https://colab.research.google.com/github/rajeevrpandey/Spam-Classifier/blob/main/Spam_Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tarfile
from pathlib import Path
import urllib.request

def fetch_spam_data():
    spam_root = "https://spamassassin.apache.org/old/publiccorpus/"
    ham_url = spam_root + "20030228_easy_ham.tar.bz2"
    spam_url = spam_root + "20030228_spam.tar.bz2"

    # Build the path datasets/spam (relative to the working directory) with pathlib's Path
    spam_path = Path() / "datasets" / "spam"
    # Create the directory along with any missing parent directories (parents=True);
    # exist_ok=True avoids an error if it already exists
    spam_path.mkdir(parents=True, exist_ok=True)

    for dir_name, tar_name, url in (("easy_ham", "ham", ham_url),
                                    ("spam", "spam", spam_url)):
        # is_dir() returns True only if the path exists and is a directory,
        # so each archive is downloaded and extracted at most once
        if not (spam_path / dir_name).is_dir():
            # with_suffix() replaces the file's suffix, if any, with ".tar.bz2"
            path = (spam_path / tar_name).with_suffix(".tar.bz2")
            print("Downloading", path)
            # urllib.request opens and reads URLs. The urllib package also contains
            # urllib.error (exceptions raised by urllib.request), urllib.parse
            # (URL parsing) and urllib.robotparser (robots.txt parsing)
            urllib.request.urlretrieve(url, path)
            # The with block closes the archive even if extraction fails
            with tarfile.open(path, "r:bz2") as tar_bz2_file:
                tar_bz2_file.extractall(path=spam_path)
    # Return the paths of the extracted "easy_ham" and "spam" directories
    return [spam_path / dir_name for dir_name in ("easy_ham", "spam")]
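To see the `with_suffix()` and `tarfile` steps without downloading anything, the sketch below builds a tiny `.tar.bz2` archive locally and extracts it the same way `fetch_spam_data()` does. The archive contents and the member name `easy_ham/00001.msg` are made up for illustration.

```python
import io
import tarfile
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # with_suffix() turns the bare name "ham" into "ham.tar.bz2"
    archive = (root / "ham").with_suffix(".tar.bz2")
    payload = b"Subject: hello\n\nhi there\n"

    # Build a tiny bzip2-compressed tar archive holding a hypothetical easy_ham/00001.msg
    with tarfile.open(archive, "w:bz2") as tar:
        info = tarfile.TarInfo(name="easy_ham/00001.msg")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))

    # Extract it the same way fetch_spam_data() does
    with tarfile.open(archive, "r:bz2") as tar:
        tar.extractall(path=root)

    # Record the results before the temporary directory is deleted
    archive_name = archive.name
    extracted_dir_exists = (root / "easy_ham").is_dir()
    extracted_payload = (root / "easy_ham" / "00001.msg").read_bytes()

print(archive_name, extracted_dir_exists, extracted_payload == payload)
```

Extracting into `root` recreates the `easy_ham/` directory from the member path, which is why `fetch_spam_data()` only needs to check `is_dir()` to know whether extraction already happened.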

In [2]:
ham_dir, spam_dir = fetch_spam_data()

In [3]:
# Keep only the email files: their names are long (a sequence number followed by a hash),
# while auxiliary files such as "cmds" have short names, so len(f.name) > 20 filters them out
ham_filenames = [f for f in sorted(ham_dir.iterdir()) if len(f.name) > 20]
spam_filenames = [f for f in sorted(spam_dir.iterdir()) if len(f.name) > 20]
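A quick check of the length filter on a throwaway directory. The file names below are hypothetical, chosen only to mimic the corpus layout (long email file names plus a short `cmds` file).

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    corpus_dir = Path(tmp)
    # Hypothetical names: two long email-style names and one short auxiliary file
    for name in ("00001.0123456789abcdef0123456789abcdef",
                 "00002.fedcba9876543210fedcba9876543210",
                 "cmds"):
        (corpus_dir / name).write_text("placeholder")
    # Same filter as above: sorted directory listing, names longer than 20 characters
    kept = [f.name for f in sorted(corpus_dir.iterdir()) if len(f.name) > 20]

print(kept)
```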

In [4]:
len(ham_filenames)

2500

In [5]:
len(spam_filenames)

500

In [6]:
# We can use Python's email module to parse these emails (it handles headers, encodings, and so on):
import email         # parse, create, and manipulate email messages
import email.parser  # provides BytesParser, used below
import email.policy  # policies that control how messages are parsed and generated

def parse_email(filepath):  # filepath is the path to one email file
    # Open in binary mode ("rb"): BytesParser expects bytes, and each message's declared
    # charset is applied later when its content is read. Other mode letters:
    # 'r' read, 'w' write, 'a' append, '+' read and write, 't' text (the default)
    with open(filepath, "rb") as f:
        # policy=email.policy.default makes the parser return EmailMessage objects,
        # which provide the modern API (get_content(), walk(), ...)
        return email.parser.BytesParser(policy=email.policy.default).parse(f)
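The same parser can be tried on an in-memory message. This sketch uses `BytesParser.parsebytes()` instead of `parse()` (same parser, bytes instead of a file), and the message itself is invented for illustration.

```python
import email.parser
import email.policy

raw = (b"From: alice@example.com\n"
       b"To: bob@example.com\n"
       b"Subject: Lunch?\n"
       b"Content-Type: text/plain; charset=utf-8\n"
       b"\n"
       b"Are you free at noon?\n")

msg = email.parser.BytesParser(policy=email.policy.default).parsebytes(raw)
print(msg["Subject"])               # header lookup by name
print(msg.get_content_type())       # MIME type of the body
print(msg.get_content().strip())    # decoded body text
```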

In [7]:
ham_emails = [parse_email(filepath) for filepath in ham_filenames]
# list of ham email message objects
spam_emails = [parse_email(filepath) for filepath in spam_filenames]
# list of spam email message objects

In [8]:
# Let's look at one example of ham and one example of spam, to get a feel for what the data looks like:

In [9]:
print(ham_emails[1].get_content().strip())

Martin A posted:
Tassos Papadopoulos, the Greek sculptor behind the plan, judged that the
 limestone of Mount Kerdylio, 70 miles east of Salonika and not far from the
 Mount Athos monastic community, was ideal for the patriotic sculpture. 
 
 As well as Alexander's granite features, 240 ft high and 170 ft wide, a
 museum, a restored amphitheatre and car park for admiring crowds are
planned
---------------------
So is this mountain limestone or granite?
If it's limestone, it'll weather pretty fast.

------------------------ Yahoo! Groups Sponsor ---------------------~-->
4 DVDs Free +s&p Join Now
http://us.click.yahoo.com/pt6YBB/NXiEAA/mG3HAA/7gSolB/TM
---------------------------------------------------------------------~->

To unsubscribe from this group, send an email to:
forteana-unsubscribe@egroups.com

 

Your use of Yahoo! Groups is subject to http://docs.yahoo.com/info/terms/


In [10]:
print(spam_emails[6].get_content().strip())

Help wanted.  We are a 14 year old fortune 500 company, that is
growing at a tremendous rate.  We are looking for individuals who
want to work from home.

This is an opportunity to make an excellent income.  No experience
is required.  We will train you.

So if you are looking to be employed from home with a career that has
vast opportunities, then go:

http://www.basetel.com/wealthnow

We are looking for energetic and self motivated people.  If that is you
than click on the link and fill out the form, and one of our
employement specialist will contact you.

To be removed from our link simple go to:

http://www.basetel.com/remove.html


4139vOLW7-758DoDY1425FRhM1-764SMFc8513fCsLl40


In [11]:
# Some emails are actually multipart, with images and attachments (which can have their own attachments). Let's look at the various types of structures we have:

In [12]:

def get_email_structure(email):
    # Base case: a string is returned unchanged, as it already describes the structure
    if isinstance(email, str):
        return email
    # get_payload() returns the message body without the headers: a list of
    # sub-messages for a multipart message, or a single payload otherwise
    payload = email.get_payload()
    if isinstance(payload, list):
        # Multipart message: describe each sub-message recursively and join the
        # results with commas, e.g. "text/plain, application/pgp-signature"
        return ", ".join([get_email_structure(sub_email) for sub_email in payload])
    else:
        # Single-part message: its MIME content type describes the structure
        return email.get_content_type()
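To see what the function returns, the sketch below builds two small messages with `email.message.EmailMessage`: a plain-text one, and one with plain-text and HTML alternatives (the kind of spam structure counted below). It repeats the function's logic under the same name so the snippet runs on its own.

```python
from email.message import EmailMessage

def get_email_structure(msg):
    # Same logic as the notebook's get_email_structure()
    if isinstance(msg, str):
        return msg
    payload = msg.get_payload()
    if isinstance(payload, list):
        return ", ".join(get_email_structure(part) for part in payload)
    return msg.get_content_type()

plain = EmailMessage()
plain.set_content("Just text")

alt = EmailMessage()
alt.set_content("Plain version")
# add_alternative() turns the message into multipart/alternative with two parts
alt.add_alternative("<p>HTML version</p>", subtype="html")

print(get_email_structure(plain))  # text/plain
print(get_email_structure(alt))    # text/plain, text/html
```

Because sub-structures are simply joined with commas, nested multipart messages are flattened into one list of content types.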

In [13]:
# The built-in "collections" module provides specialized container datatypes in addition to Python's general-purpose built-in ones.
from collections import Counter
# Counter counts the occurrences of elements in an iterable or a mapping.

def structures_counter(emails):
    structures = Counter()  # tallies how often each email structure occurs
    for email in emails:
        structure = get_email_structure(email)
        # A structure seen for the first time counts as 0, so no initialization is needed
        structures[structure] += 1
    # Counter mapping each observed structure to its number of occurrences
    return structures

In [14]:
# The most_common() method is a built-in method of the Counter class that returns a list of the n most common elements and their counts from the most common to the least.
# If no value for n is specified, it returns all elements in the Counter.
structures_counter(ham_emails).most_common()

[('text/plain', 2411),
 ('text/plain, application/pgp-signature', 66),
 ('text/plain, text/html', 8),
 ('text/plain, text/plain', 5),
 ('text/plain, application/octet-stream', 2),
 ('text/plain, text/enriched', 1),
 ('text/plain, application/ms-tnef, text/plain', 1),
 ('text/plain, text/plain, text/plain, application/pgp-signature', 1),
 ('text/plain, video/mng', 1),
 ('text/plain, application/x-pkcs7-signature', 1),
 ('text/plain, text/plain, text/plain, text/rfc822-headers', 1),
 ('text/plain, text/plain, text/plain, text/plain, application/x-pkcs7-signature',
  1),
 ('text/plain, application/x-java-applet', 1)]

In [15]:
structures_counter(spam_emails).most_common()

[('text/plain', 237),
 ('text/html', 208),
 ('text/plain, text/html', 45),
 ('text/plain, image/jpeg', 3),
 ('text/html, application/octet-stream', 2),
 ('text/plain, application/octet-stream', 1),
 ('text/html, text/plain', 1),
 ('text/html, application/octet-stream, image/jpeg', 1),
 ('text/plain, text/html, image/gif', 1),
 ('multipart/alternative', 1)]

In [16]:
# It seems that the ham emails are more often plain text, while spam has quite a lot of HTML.
# Moreover, quite a few ham emails are signed using PGP, while no spam is.
# In short, it seems that the email structure is useful information to have.

In [17]:
# Now let's take a look at the email headers:

In [18]:
# items() method returns a list of 2-tuples containing all the message’s field headers and values.
for header, value in spam_emails[0].items():
  print(header, ":", value)

Return-Path : <12a1mailbot1@web.de>
Delivered-To : zzzz@localhost.spamassassin.taint.org
Received : from localhost (localhost [127.0.0.1])	by phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 136B943C32	for <zzzz@localhost>; Thu, 22 Aug 2002 08:17:21 -0400 (EDT)
Received : from mail.webnote.net [193.120.211.219]	by localhost with POP3 (fetchmail-5.9.0)	for zzzz@localhost (single-drop); Thu, 22 Aug 2002 13:17:21 +0100 (IST)
Received : from dd_it7 ([210.97.77.167])	by webnote.net (8.9.3/8.9.3) with ESMTP id NAA04623	for <zzzz@spamassassin.taint.org>; Thu, 22 Aug 2002 13:09:41 +0100
From : 12a1mailbot1@web.de
Received : from r-smtp.korea.com - 203.122.2.197 by dd_it7  with Microsoft SMTPSVC(5.5.1775.675.6);	 Sat, 24 Aug 2002 09:42:10 +0900
To : dcek1a1@netsgo.com
Subject : Life Insurance - Why Pay More?
Date : Wed, 21 Aug 2002 20:31:57 -1600
MIME-Version : 1.0
Message-ID : <0103c1042001882DD_IT7@dd_it7>
Content-Type : text/html; charset="iso-8859-1"
Content-Transfer-Encoding : qu

In [19]:
# The headers hold a lot of useful information,
# such as the subject of the email above:
spam_emails[0]["Subject"]

'Life Insurance - Why Pay More?'
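Some headers, such as `Received` in the spam email above, appear several times. With the `email` API, `msg[name]` returns only the first occurrence, while `get_all(name)` returns every occurrence in order. The message below is a hypothetical one built just to show this.

```python
from email.message import EmailMessage

msg = EmailMessage()
msg["Subject"] = "Life Insurance - Why Pay More?"
msg["Received"] = "from relay1.example.com"
msg["Received"] = "from relay2.example.com"  # assigning again ADDS a second header

print(msg["Subject"])            # the Subject header
print(msg["Received"])           # only the first Received header
print(msg.get_all("Received"))   # every Received header, in order
print(len(msg.items()))          # number of (header, value) pairs
```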

In [20]:
# Split the data into train and test:
import numpy as np
from sklearn.model_selection import train_test_split

X = np.array(ham_emails + spam_emails, dtype=object)
y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))  # labels: 0 means ham, 1 means spam

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [21]:
from bs4 import BeautifulSoup
import html

'''
The following function drops the <head> section, replaces the text of every <a> tag with the word HYPERLINK, and strips all remaining HTML tags, leaving only the text.
get_text(separator='\n', strip=True) strips the whitespace around each text fragment and skips empty fragments, so runs of blank lines disappear. Finally, html.unescape() decodes any HTML entities (such as &gt; or &nbsp;) that survive parsing:
'''

def html_to_plain_text(html_content):
    # Create a Beautiful Soup object
    soup = BeautifulSoup(html_content, 'html.parser')

    # Drop the <head> section
    for head in soup(['head']):
        head.decompose()

    # Convert all <a> tags to the word "HYPERLINK"
    for a_tag in soup('a'):
        a_tag.string = ' HYPERLINK '

    # Extract text content and remove excess whitespaces and line breaks
    text_content = soup.get_text(separator='\n', strip=True)

    # Unescape HTML entities
    text_content = html.unescape(text_content)

    return text_content
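For comparison, here is a dependency-free sketch of the same three steps using the standard library's `html.parser.HTMLParser`, which is also the parser behind BeautifulSoup's `'html.parser'` backend. This is a swapped-in technique, not what the notebook uses; `_PlainTextExtractor` and `html_to_plain_text_stdlib` are hypothetical names, and the sketch only handles the cases described above.

```python
from html.parser import HTMLParser

class _PlainTextExtractor(HTMLParser):
    """Collects text outside <head>, replacing each <a> element with HYPERLINK."""

    def __init__(self):
        super().__init__(convert_charrefs=True)  # decodes &gt;, &nbsp;, ... in text
        self.chunks = []
        self.in_head = False
        self.in_link = False

    def handle_starttag(self, tag, attrs):
        if tag == "head":
            self.in_head = True
        elif tag == "a":
            self.in_link = True
            self.chunks.append("HYPERLINK")

    def handle_endtag(self, tag):
        if tag == "head":
            self.in_head = False
        elif tag == "a":
            self.in_link = False

    def handle_data(self, data):
        # Skip text inside <head> and the original link text
        if not self.in_head and not self.in_link:
            self.chunks.append(data)

def html_to_plain_text_stdlib(html_content):
    parser = _PlainTextExtractor()
    parser.feed(html_content)
    parser.close()
    # Like get_text(separator="\n", strip=True): strip each fragment, drop empty ones
    lines = [chunk.strip() for chunk in parser.chunks]
    return "\n".join(line for line in lines if line)

sample = ("<html><head><title>T</title></head><body>"
          "<p>Cheap &gt; ink</p><a href='x'>click</a></body></html>")
print(html_to_plain_text_stdlib(sample))
```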



In [22]:
# Let's see if it works. This is HTML spam:

# Iterate through the training emails labeled as spam (y_train==1) and keep those whose structure is exactly "text/html"
html_spam_emails = [email for email in X_train[y_train==1] if get_email_structure(email) == "text/html"] # list of HTML spam emails
sample_html_spam = html_spam_emails[5]
# The content of the selected HTML spam email is obtained using the get_content() method.
# The strip() method is used to remove any leading or trailing whitespace from the content, and a slice of the first 1000 characters is extracted for display.
print(sample_html_spam.get_content().strip()[:1000], "...")

<HR>
<html>
<div bgcolor="#FFFFCC">

  <p align="center"><a
href="http://www.webbasedmailing.com"><img border="0"
src="http://www.webbasedmailing.com/Toners2goLogo.jpg"
width="349" height="96"></a></p>
<p align="center"><font size="6" face="Arial MT
Black"><i>Tremendous Savings</i>
on Toners,&nbsp;</font></p>
<p align="center"><font size="6" face="Arial MT
Black">
Inkjets, FAX, and Thermal Replenishables!!</font></p>
<p><a href="http://www.webbasedmailing.com">Toners 2 Go
</a>is your secret
weapon to lowering your cost for <a
href="http://www.webbasedmailing.com">High Quality,
Low-Cost</a> printer
supplies!&nbsp; We have been in the printer
replenishables business since 1992,
and pride ourselves on rapid response and outstanding
customer service.&nbsp;
What we sell are 100% compatible replacements for
Epson, Canon, Hewlett Packard,
Xerox, Okidata, Brother, and Lexmark; products that
meet and often exceed
original manufacturer's specifications.</p>
<p><i><font size="4">Check out these
p

In [23]:
# the resulting plain text after passing it through html_to_plain_text():
print(html_to_plain_text(sample_html_spam.get_content())[:1000], "...")

HYPERLINK
Tremendous Savings
on Toners,
Inkjets, FAX, and Thermal Replenishables!!
HYPERLINK
is your secret
weapon to lowering your cost for
HYPERLINK
printer
supplies!  We have been in the printer
replenishables business since 1992,
and pride ourselves on rapid response and outstanding
customer service. 
What we sell are 100% compatible replacements for
Epson, Canon, Hewlett Packard,
Xerox, Okidata, Brother, and Lexmark; products that
meet and often exceed
original manufacturer's specifications.
Check out these
prices!
Epson Stylus
Color inkjet cartridge
(SO20108):     Epson's Price:
$27.99
Toners2Go price: $9.95!
HP
LaserJet 4 Toner Cartridge
(92298A):           
HP's
Price:
$88.99
Toners2Go
  price: $41.75!
Come visit us on the web to check out our hundreds
of similar bargains at
HYPERLINK
!
request to be excluded by visiting
HYPERLINK
beverley ...


In [24]:
# A function that takes an email as input and returns its content as plain text, whatever its format:

def email_to_text(email):
    html = None
    # walk() iterates over the message itself and all of its sub-parts, recursively
    for part in email.walk():
        ctype = part.get_content_type()
        # Skip anything that is not plain text or HTML (attachments, signatures, containers, ...)
        if ctype not in ("text/plain", "text/html"):
            continue
        try:
            content = part.get_content()
        except Exception:  # e.g. an unknown or invalid charset
            content = str(part.get_payload())  # fall back to the raw payload as a string
        if ctype == "text/plain":
            # The first plain-text part wins
            return content
        else:
            # Remember the HTML in case no plain-text part is found
            html = content
    if html:  # no plain-text part: convert the (last) HTML part to plain text
        return html_to_plain_text(html)
    return None  # the email has no text/plain or text/html part
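The part order that `walk()` produces decides which text `email_to_text()` picks. The sketch below shows that order on a hypothetical multipart message and mirrors the preference rule in `first_text_body`, a hypothetical helper that returns the HTML raw instead of calling `html_to_plain_text()`, so it runs without BeautifulSoup.

```python
from email.message import EmailMessage

msg = EmailMessage()
msg.set_content("Plain-text body")
msg.add_alternative("<p>HTML body</p>", subtype="html")

# walk() yields the container itself first, then every sub-part (depth first)
print([part.get_content_type() for part in msg.walk()])

html_only = EmailMessage()
html_only.set_content("<p>Only HTML</p>", subtype="html")

def first_text_body(msg):
    # Same preference order as email_to_text(), minus the HTML-to-text conversion
    html_fallback = None
    for part in msg.walk():
        ctype = part.get_content_type()
        if ctype == "text/plain":
            return part.get_content()           # the first plain-text part wins
        if ctype == "text/html":
            html_fallback = part.get_content()  # remember the HTML, keep looking
    return html_fallback

print(first_text_body(msg).strip())        # Plain-text body
print(first_text_body(html_only).strip())  # <p>Only HTML</p>
```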


In [25]:
# now we can get plain text for any format
print(email_to_text(sample_html_spam)[:100], "...")

HYPERLINK
Tremendous Savings
on Toners,
Inkjets, FAX, and Thermal Replenishables!!
HYPERLINK
is your ...


In [26]:
# Let's throw in some stemming! We will use the Natural Language Toolkit (NLTK):
# Stemming is the process of reducing words to their root or base form, known as the word stem.
# Stemming algorithms derive this base form by stripping suffixes (and, in some algorithms, prefixes); the Porter stemmer strips suffixes only.
# PorterStemmer is a widely used, rule-based English stemmer. Mapping variants such as "computing" and "computed" to one stem shrinks the vocabulary and lets related words share their counts; note that stems are not always real words ("comput").
import nltk

stemmer = nltk.PorterStemmer()
for word in ("Computations", "Computation", "Computing", "Computed", "Compute",
             "Compulsive"):
    print(word, "=>", stemmer.stem(word))

Computations => comput
Computation => comput
Computing => comput
Computed => comput
Compute => comput
Compulsive => compuls
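In the transformer further down, stemming is applied to word counts, so words that share a stem have their counts added together. The sketch below shows just that merge step. To keep it free of NLTK, `toy_stem` is a hypothetical stand-in that only strips a trailing "s", far cruder than `PorterStemmer`.

```python
from collections import Counter

def toy_stem(word):
    # Hypothetical stand-in for stemmer.stem(): strip a single trailing "s"
    return word[:-1] if word.endswith("s") else word

word_counts = Counter({"offers": 2, "offer": 1, "free": 3})

stemmed_word_counts = Counter()
for word, count in word_counts.items():
    # Counts of words that share a stem are added together
    stemmed_word_counts[toy_stem(word)] += count

print(dict(stemmed_word_counts))  # {'offer': 3, 'free': 3}
```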


In [27]:
# Let's use the urlextract library to replace URLs with the word "URL"
%pip install -q -U urlextract

In [28]:
# The urlextract library identifies and extracts URLs from text
import urlextract

url_extractor = urlextract.URLExtract()
some_text = "Will it detect github.com and https://youtu.be/7Pq-S557XQU?t=3m32s"
print(url_extractor.find_urls(some_text))

['github.com', 'https://youtu.be/7Pq-S557XQU?t=3m32s']


In [29]:
'''
We are ready to put all this together into a transformer that we will use to convert emails to word counters.
Note that we split sentences into words using Python's split() method, which uses whitespace as word boundaries.
This works for many written languages, but not all.
For example, Chinese and Japanese scripts generally don't use spaces between words, and Vietnamese often uses spaces even between syllables.
It's okay here, because the dataset is (mostly) in English.
'''

from sklearn.base import BaseEstimator, TransformerMixin
import re
from collections import Counter

class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, strip_headers=True, lower_case=True,
                 remove_punctuation=True, replace_urls=True,
                 replace_numbers=True, stemming=True):
        # Hyperparameters that switch each preprocessing step on or off.
        # Note: strip_headers is stored but not used, because email_to_text()
        # only ever returns the body, so headers are always dropped.
        self.strip_headers = strip_headers
        self.lower_case = lower_case
        self.remove_punctuation = remove_punctuation
        self.replace_urls = replace_urls
        self.replace_numbers = replace_numbers
        self.stemming = stemming

    def fit(self, X, y=None):
        # Nothing to learn: fit() just returns self so the transformer works in a Pipeline
        return self

    def transform(self, X, y=None):
        X_transformed = []
        for email in X:
            # email_to_text() returns None when the email has no text part
            text = email_to_text(email) or ""
            if self.lower_case:
                text = text.lower()
            if self.replace_urls and url_extractor is not None:
                # Replace each distinct URL with the token "URL", longest first,
                # so a URL contained in a longer one is not replaced piecemeal
                urls = list(set(url_extractor.find_urls(text)))
                urls.sort(key=lambda url: len(url), reverse=True)
                for url in urls:
                    text = text.replace(url, " URL ")
            if self.replace_numbers:
                # Replace integers, decimals and exponents (e.g. 42, 9.95, 1e6) with "NUMBER"
                text = re.sub(r'\d+(?:\.\d*)?(?:[eE][+-]?\d+)?', 'NUMBER', text)
            if self.remove_punctuation:
                # Replace every run of non-word characters with a single space
                text = re.sub(r'\W+', ' ', text, flags=re.M)

            # Split the preprocessed text on whitespace and count each word
            word_counts = Counter(text.split())

            if self.stemming and stemmer is not None:
                # Merge the counts of words that share the same stem
                # (e.g. "computing" and "computed" both become "comput")
                stemmed_word_counts = Counter()
                for word, count in word_counts.items():
                    stemmed_word = stemmer.stem(word)
                    stemmed_word_counts[stemmed_word] += count
                word_counts = stemmed_word_counts
            # One Counter of (stemmed) word counts per email
            X_transformed.append(word_counts)
        # dtype=object keeps each Counter as a single array element
        return np.array(X_transformed, dtype=object)
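The URL, number, and punctuation steps can be checked in isolation with the standard library alone. In this sketch, `replace_urls` and `normalize` are hypothetical helpers, and the URL list is passed in by hand instead of coming from urlextract; the regular expressions are the ones used in `transform()`.

```python
import re

NUMBER_RE = re.compile(r'\d+(?:\.\d*)?(?:[eE][+-]?\d+)?')

def replace_urls(text, urls):
    # Longest URLs first, as in transform(), so "example.com/offer"
    # is replaced whole before its prefix "example.com" is
    for url in sorted(set(urls), key=len, reverse=True):
        text = text.replace(url, " URL ")
    return text

def normalize(text, urls):
    # URL, number and punctuation steps of the transformer, then a whitespace split
    text = replace_urls(text, urls)
    text = NUMBER_RE.sub("NUMBER", text)
    text = re.sub(r'\W+', ' ', text)
    return text.split()

sample = "visit example.com/offer or example.com, save 9.95 now"
print(normalize(sample, ["example.com", "example.com/offer"]))
# ['visit', 'URL', 'or', 'URL', 'save', 'NUMBER', 'now']

# With the shorter URL replaced first, "/offer" would be left behind as a stray word
wrong_order = "visit example.com/offer"
for url in ["example.com", "example.com/offer"]:
    wrong_order = wrong_order.replace(url, " URL ")
print(re.sub(r'\W+', ' ', wrong_order).split())  # ['visit', 'URL', 'offer']
```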

In [30]:
# trying this transformer on a few emails:
X_few = X_train[:3]
X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)
X_few_wordcounts

array([Counter({'chuck': 1, 'murcko': 1, 'wrote': 1, 'stuff': 1, 'yawn': 1, 'r': 1}),
       Counter({'the': 11, 'of': 9, 'and': 8, 'all': 3, 'christian': 3, 'to': 3, 'by': 3, 'jefferson': 2, 'i': 2, 'have': 2, 'superstit': 2, 'one': 2, 'on': 2, 'been': 2, 'ha': 2, 'half': 2, 'rogueri': 2, 'teach': 2, 'jesu': 2, 'some': 1, 'interest': 1, 'quot': 1, 'url': 1, 'thoma': 1, 'examin': 1, 'known': 1, 'word': 1, 'do': 1, 'not': 1, 'find': 1, 'in': 1, 'our': 1, 'particular': 1, 'redeem': 1, 'featur': 1, 'they': 1, 'are': 1, 'alik': 1, 'found': 1, 'fabl': 1, 'mytholog': 1, 'million': 1, 'innoc': 1, 'men': 1, 'women': 1, 'children': 1, 'sinc': 1, 'introduct': 1, 'burnt': 1, 'tortur': 1, 'fine': 1, 'imprison': 1, 'what': 1, 'effect': 1, 'thi': 1, 'coercion': 1, 'make': 1, 'world': 1, 'fool': 1, 'other': 1, 'hypocrit': 1, 'support': 1, 'error': 1, 'over': 1, 'earth': 1, 'six': 1, 'histor': 1, 'american': 1, 'john': 1, 'e': 1, 'remsburg': 1, 'letter': 1, 'william': 1, 'short': 1, 'again': 1, 'becom

In [31]:
# Now we have the word counts, and we need to convert them to vectors.
# For this, we will build another transformer whose fit() method will build the vocabulary (an ordered list of the most common words)
# and whose transform() method will use the vocabulary to convert word counts to vectors.
# The output is a sparse matrix.

The `enumerate()` function takes an iterable (e.g. a tuple) and returns an enumerate object.
The enumerate object yields `(index, item)` pairs: a counter, starting at 0 by default, paired with each item.


```
x = ('apple', 'banana', 'cherry')
y = enumerate(x)

print(list(y))
```
output:

```
[(0, 'apple'), (1, 'banana'), (2, 'cherry')]
```
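The vocabulary in the next cell is built with exactly this pattern, shifting every index by 1 so that index 0 stays free for unknown words. The word/count pairs below are illustrative; the `start=1` variant is an equivalent alternative, not what the notebook uses.

```python
most_common = [("the", 11), ("of", 9), ("and", 8)]

# As in WordCounterToVectorTransformer.fit(): shift each index by 1
vocabulary = {word: index + 1 for index, (word, count) in enumerate(most_common)}

# Equivalent: let enumerate() start counting at 1
vocabulary_start1 = {word: index for index, (word, _count) in enumerate(most_common, start=1)}

print(vocabulary)  # {'the': 1, 'of': 2, 'and': 3}
```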






In [32]:
from scipy.sparse import csr_matrix
# Transforms word counts into a sparse matrix suitable for machine learning models
class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, vocabulary_size=1000):
        self.vocabulary_size = vocabulary_size
    def fit(self, X, y=None):
        # Build the vocabulary from the most common words in the training word counts X
        total_count = Counter()
        for word_count in X:
            for word, count in word_count.items():
                # Cap each email's contribution at 10 occurrences, so a single email
                # that repeats one word many times cannot dominate the vocabulary
                total_count[word] += min(count, 10)

        # most_common() returns (word, count) pairs sorted by decreasing count;
        # keep only the top vocabulary_size words
        most_common = total_count.most_common()[:self.vocabulary_size]
        # Map each word to a column index starting at 1 (its rank in most_common, plus 1);
        # index 0 is reserved for out-of-vocabulary words
        self.vocabulary_ = {word: index + 1 for index, (word, count) in enumerate(most_common)}
        return self

    def transform(self, X, y=None):
        # Build the matrix in coordinate form, entry (row, col) = count:
        # each row is one email, each column one vocabulary word
        rows = []
        cols = []
        data = []
        for row, word_count in enumerate(X):
            for word, count in word_count.items():
                rows.append(row)
                # Column index of the word, or 0 if it is not in the vocabulary
                cols.append(self.vocabulary_.get(word, 0))
                data.append(count)
        # csr_matrix sums duplicate (row, col) entries, so column 0 ends up holding
        # the total count of all out-of-vocabulary words in each email
        return csr_matrix((data, (rows, cols)),
                          shape=(len(X), self.vocabulary_size + 1))
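A dense, pure-Python equivalent of `transform()` makes the column layout explicit. `to_dense_rows` is a hypothetical helper, and the vocabulary and word counts are toy values; the `+=` plays the role of `csr_matrix` summing duplicate entries in column 0.

```python
from collections import Counter

def to_dense_rows(word_counts_list, vocabulary, vocabulary_size):
    # Column 0 collects every out-of-vocabulary word; column i holds the word with index i
    rows = []
    for word_counts in word_counts_list:
        row = [0] * (vocabulary_size + 1)
        for word, count in word_counts.items():
            row[vocabulary.get(word, 0)] += count
        rows.append(row)
    return rows

vocabulary = {"the": 1, "of": 2}
emails = [Counter({"the": 3, "cat": 2, "dog": 1}), Counter({"of": 4})]
print(to_dense_rows(emails, vocabulary, vocabulary_size=2))
# [[3, 3, 0], [0, 0, 4]]
```

The sparse format stores only the non-zero entries, which matters because real emails use a tiny fraction of a 1,000-word vocabulary.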


In [33]:
# Using the transformers:
vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)
X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)
X_few_vectors

<3x11 sparse matrix of type '<class 'numpy.int64'>'
	with 20 stored elements in Compressed Sparse Row format>

In [34]:
X_few_vectors.toarray()

array([[ 6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [99, 11,  9,  8,  3,  1,  3,  1,  3,  2,  3],
       [67,  0,  1,  2,  3,  4,  1,  2,  0,  1,  0]])

What does this matrix mean? Well, the 99 in the second row, first column, means that the second email contains 99 words that are not part of the vocabulary. The 11 next to it means that the first word in the vocabulary is present 11 times in this email. The 9 next to it means that the second word is present 9 times, and so on. You can look at the vocabulary to know which words we are talking about. The first word is "the", the second word is "of", etc.

In [35]:
vocab_transformer.vocabulary_

{'the': 1,
 'of': 2,
 'and': 3,
 'to': 4,
 'url': 5,
 'all': 6,
 'in': 7,
 'christian': 8,
 'on': 9,
 'by': 10}
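Pairing the second row of `X_few_vectors.toarray()` with this vocabulary reproduces the reading given above. The row and vocabulary are copied from the outputs; the label `"<out-of-vocabulary>"` for column 0 is a hypothetical name.

```python
row = [99, 11, 9, 8, 3, 1, 3, 1, 3, 2, 3]
vocabulary = {'the': 1, 'of': 2, 'and': 3, 'to': 4, 'url': 5,
              'all': 6, 'in': 7, 'christian': 8, 'on': 9, 'by': 10}

# Invert the vocabulary: column index -> word
index_to_word = {index: word for word, index in vocabulary.items()}
index_to_word[0] = "<out-of-vocabulary>"  # hypothetical label for column 0

decoded = {index_to_word[i]: count for i, count in enumerate(row)}
print(decoded)
print(sum(row))  # total number of (stemmed) words in the second email
```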

In [36]:
# Training the spam classifier
from sklearn.pipeline import Pipeline

preprocess_pipeline = Pipeline([
    ("email_to_wordcount", EmailToWordCounterTransformer()),
    ("wordcount_to_vector", WordCounterToVectorTransformer()),
])

X_train_transformed = preprocess_pipeline.fit_transform(X_train)

In [37]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

log_clf = LogisticRegression(max_iter=1000, random_state=42)
score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3)
score.mean()

0.9858333333333333

In [38]:
# precision/recall we get on the test set:
from sklearn.metrics import precision_score, recall_score

X_test_transformed = preprocess_pipeline.transform(X_test)

log_clf = LogisticRegression(max_iter=1000, random_state=42)
log_clf.fit(X_train_transformed, y_train)

y_pred = log_clf.predict(X_test_transformed)

print(f"Precision: {precision_score(y_test, y_pred):.2%}")
print(f"Recall: {recall_score(y_test, y_pred):.2%}")

Precision: 96.88%
Recall: 97.89%
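Precision is the fraction of emails flagged as spam that really are spam, TP / (TP + FP); recall is the fraction of spam emails that get flagged, TP / (TP + FN). The sketch below computes both by hand on toy labels (not the notebook's predictions); `precision_recall` is a hypothetical helper, not part of scikit-learn.

```python
def precision_recall(y_true, y_pred):
    # 1 = spam (positive class), 0 = ham
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

y_true_toy = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred_toy = [1, 1, 0, 0, 1, 0, 0, 0]  # TP=2, FP=1, FN=2
p, r = precision_recall(y_true_toy, y_pred_toy)
print(f"Precision: {p:.2%}")  # Precision: 66.67%
print(f"Recall: {r:.2%}")     # Recall: 50.00%
```

For a spam filter, precision is often the more sensitive number, since a false positive means a legitimate email lands in the spam folder.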
