In [80]:
import math, random, re, pathlib, os
from collections import Counter, defaultdict


def get_data():
    data = []
    emaildir = pathlib.Path.cwd() / 'emails'
    for dirpath, dirs, files in os.walk(emaildir, topdown=False):
        is_spam = 'spam' in dirpath
        for file in files:
            with open(pathlib.Path(dirpath) / file, 'rb') as email:
                for line in email:
                    line = line.decode('utf-8', 'ignore')
                    if line.startswith('Subject:'):
                        line = line.replace('Subject:', '')
                        line = line.strip()
                        data.append((line, is_spam))
    return data
                    
    
def split(data, split_fraction=0.7):
    shuffled_data = random.sample(data, k=len(data))
    split_ix = int(split_fraction * len(data))
    return shuffled_data[:split_ix], shuffled_data[split_ix:]
    
    
def counts(data):
    cts = Counter([label for _, label in data])
    return cts[True], cts[False]
    
    
def tokenize(message):
    message = message.lower()
    words = re.findall(r"([a-z0-9']+)", message)
    return set(words)
    
    
def word_counts(data):
    wcounts = defaultdict(lambda: [0,0])
    for message, is_spam in data:
        words = tokenize(message)
        for word in words:
            wcounts[word][0 if is_spam else 1] += 1
    return wcounts
    
    
def word_probabilities(wcounts, tot_spam, tot_nonspam, k=0.5):
    # word, p(word | spam), p(word | nonspam)
    # laplace smoothing
    return [
        (word, (ct_spam + k)/(tot_spam + 2*k), (ct_nonspam + k)/(tot_nonspam + 2*k))
        for word, (ct_spam, ct_nonspam) in wcounts.items()
    ]
    
    
def nbclassify(wprobs, message):
    # p(spam | message) = 
    #      p(message | spam)*p(spam) / 
    #           [p(message | spam)*p(spam) + p(message | nonspam)*p(nonspam)]
    log_prob_spam = log_prob_nonspam = 0.0
    message_words = tokenize(message)
    for word, pspam, pnonspam in wprobs:
        if word in message_words:
            log_prob_spam += math.log(pspam)
            log_prob_nonspam += math.log(pnonspam)
        else:
            log_prob_spam += math.log(1-pspam)
            log_prob_nonspam += math.log(1-pnonspam)
    tot_prob_spam = math.exp(log_prob_spam)
    tot_prob_nonspam = math.exp(log_prob_nonspam)
    return tot_prob_spam / (tot_prob_spam + tot_prob_nonspam)
    
    
def runtest(wprobs, data, threshold=0.5):
    results = []
    for message, is_spam in data:
        p_spam = nbclassify(wprobs, message)
        results.append((is_spam, p_spam >= threshold))
    return Counter(results)


In [86]:
data = get_data()
training, test = split(data)
tot_spam, tot_nonspam = counts(training)
wcounts = word_counts(training)
wprobs = word_probabilities(wcounts, tot_spam, tot_nonspam, k=0.5)
runtest(wprobs, test)

Counter({(False, False): 831,
         (True, True): 94,
         (True, False): 63,
         (False, True): 39})