### Get Data

In [30]:
from io import BytesIO # So we can treat bytes as a file
import requests
import tarfile
import os

BASE_URL = "https://spamassassin.apache.org/old/publiccorpus"
FILES = ["20021010_easy_ham.tar.bz2",
         "20021010_hard_ham.tar.bz2",
         "20021010_spam.tar.bz2"]

# If the data already exists, delete then get it
from shutil import rmtree
try:
    rmtree(OUTPUT_DIR)
except FileNotFoundError:
    pass

OUTPUT_DIR = 'spam_data'

for filename in FILES:
    # use requests to get the file contents at each URL
    content = requests.get(f"{BASE_URL}/{filename}").content
    fin = BytesIO(content)
    
    # extract tarfiles to the specified output dir
    with tarfile.open(fileobj=fin, mode='r:bz2') as tf:
        tf.extractall(OUTPUT_DIR)

In [36]:
import glob, re
from typing import List
from scratch.naive_bayes import Message, NaiveBayesClassifier

path = 'spam_data/*/*'

data: List[Message] = []

In [37]:
for filename in glob.glob(path): # glob.glob returns every filename that matches the wildcarded path
    is_spam = "ham" not in filename
    
    with open(filename, errors='ignore') as email_file:
        for line in email_file:
            if line.startswith("Subject: "):
                subject = line.lstrip("Subject: ")
                data.append(Message(subject, is_spam))
                break # done with this file

### Split Data into Training and Test Sets, Train the Model

In [38]:
import random 
from scratch.machine_learning import split_data

random.seed(0)
train_messages, test_messages = split_data(data, 0.75)

model = NaiveBayesClassifier()
model.train(train_messages)

### Generate some predictions

In [39]:
from collections import Counter

predictions = [(message, model.predict(message.text))
              for message in test_messages]

# Assume that spam prob > 0.5 corresponds to spam prediction
confusion_matrix = Counter((message.is_spam, spam_probability > 0.5)
                          for message, spam_probability in predictions)
print(confusion_matrix)

Counter({(False, False): 667, (True, True): 89, (True, False): 50, (False, True): 19})
