In [1]:
import os
import email
import math
from email import policy
from tabulate import tabulate
from collections import defaultdict

class SpamNaiveBayes:
    def __init__(self, path):
        self.path = path
    
    # treina o modelo
    def fit(self, files, labels):
        self.spam = defaultdict(int)
        self.ham = defaultdict(int)
        self.total_words = self.spam_words = self.ham_words = 0
        for file in files:
            body = self.read_email(file)
            words = body.split()
            for word in words:      
                if labels[file] == 0:
                    if word in self.spam:
                        self.spam[word] += 1
                    else:
                        self.spam[word] = 1
                else:
                    if word in self.ham:
                        self.ham[word] += 1
                    else:
                        self.ham[word] = 1
        self.spam_words = sum(self.spam.values())
        self.ham_words = sum(self.ham.values())
        self.total_words = self.spam_words + self.ham_words
    
    # faz a previsão de um novo email
    def predict(self, file, labels, k=1):
        body = self.read_email(file)
        words = body.split()
        spam_prob = math.log(self.spam_words / self.total_words)
        ham_prob = math.log(self.ham_words / self.total_words)
        for word in words:
            spam_prob += math.log((self.spam[word] + k) / (self.spam_words + k*2))
            ham_prob += math.log((self.ham[word] + k) / (self.ham_words + k*2))
    
        return int(ham_prob > spam_prob)
    
    # avalia o modelo a partir de um conjunto de dados e uma função de validação
    def test(self, files, labels, k=1, function='accuracy'):
        right_prev = tp = tn = fp = fn = 0
        for file in files:
            prev = self.predict(file, labels, k)
            if prev == labels[file]:
                right_prev += 1
                if prev == 1:
                    tp += 1
                else:
                    tn += 1
            else:
                if prev == 1:
                    fp += 1
                else:
                    fn += 1
        if function == 'precision':
            return tp / (tp + fp)
        if function == 'recall':
            return tp / (tp + fn)
        return right_prev / len(files)
        
    # faz o tratamento da leitura do email
    def read_email(self, file):
        fin = open(self.path + file, encoding='ISO-8859-1')
        mail = email.message_from_file(fin, policy=policy.default)
        body = ""
        if mail.is_multipart():
            for part in mail.walk():
                ctype = part.get_content_type()
                cdispo = str(part.get('Content-Disposition'))
                if ctype == 'text/plain' and 'attachment' not in cdispo:
                    body = part.get_payload()
                    break
        else:
            body = mail.get_payload()
        return body
        

In [2]:
def train_test_split(files, split=0.20):
    sample = int(split * len(files))
    return files[sample:], files[0:sample]

def cross_validation(path, files, labels, k_values, n, function='accuracy'):    
    ps = int(len(files) / n)
    best_model = None
    best_value = 0
    for k in k_values:
        snb = SpamNaiveBayes(path)
        total_values = 0
        for i in range(0, len(files), ps):
            train, validation = files[0:i] + files[i+ps:], files[i:i+ps]
            if len(validation) == ps:
                snb.fit(train, labels)
                total_values += snb.test(validation, labels, k, function)
                
        mean = total_values / n
        print('k =', k, function, '=', mean)
        if mean > best_value:
            best_value = mean
            best_model = snb
            
    return best_model, best_value
    

In [5]:
labels = {}

label_file = open('./SPAM-DATA/SPAMTrain.label')
for row in label_file:
    label, file = row.strip().split()    
    labels[file] = int(label)

path = './SPAM-DATA/DATA/'
files = os.listdir(path)

train_mails, test_mails = train_test_split(files)

k_values = [0.01, 0.04, 0.07, 0.1, 0.4, 0.7, 1]
SpamNB, best_k = cross_validation(path, train_mails, labels, k_values=k_values, n=10)
print('Cross validation accuracy:', best_k)

k = 0.01 accuracy = 0.9771676300578035
k = 0.04 accuracy = 0.9777456647398843
k = 0.07 accuracy = 0.9774566473988437
k = 0.1 accuracy = 0.9771676300578033
k = 0.4 accuracy = 0.9745664739884393
k = 0.7 accuracy = 0.9739884393063584
k = 1 accuracy = 0.972543352601156
Cross validation accuracy: 0.9777456647398843


In [16]:
right_prev = tp = tn = fp = fn = 0
for file in test_mails:
    prev = SpamNB.predict(file, labels, best_k)
    if prev == labels[file]:
        right_prev += 1
        if prev == 1:
            tp += 1
        else:
            tn += 1
    else:
        if prev == 1:
            fp += 1
        else:
            fn += 1

accuracy = right_prev / len(test_mails)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
print('Accuracy:', accuracy)
print('Precision:', precision)
print('Recall:', recall)
print(tabulate([['Real Positva', tp, fn], ['Real Negativa', fp, tn]], headers=['', 'Predita Positiva', 'Predita Negativa']))

Accuracy: 0.9710982658959537
Precision: 0.976027397260274
Recall: 0.9810671256454389
                 Predita Positiva    Predita Negativa
-------------  ------------------  ------------------
Real Positva                  570                  11
Real Negativa                  14                 270
