In [1]:
import zipfile
import requests
from collections import Counter
import re
import math
import random

In [2]:
# 데이터 다운, 압축 풀기

r = requests.get("https://www.dt.fee.unicamp.br/~tiago/smsspamcollection/smsspamcollection.zip")
with open("sms.zip", "wb") as f:
    f.write(r.content)
zipfile.ZipFile("sms.zip").extractall("./")

In [7]:
# 데이터 로드

data = []
with open("SMSSpamCollection.txt", "r") as f:
    for line in f:
        cls, txt = line.strip().split("\t")
        bow = set(re.findall("[0-9a-z_]+", txt.lower()))
        data.append([cls, bow])

In [8]:
# train/test 데이터 나누기

random.shuffle(data)
train_size = int(len(data) * 0.8)
test_size = len(data) - train_size
train = data[:train_size]
test = data[train_size:]
print(len(train), len(test)) 

4459 1115


In [38]:
# train 데이터에서, 나이브 베이즈 계산을 위한 값 미리 계산해두기
# 스팸/일반 메일의 비율 계산: P(spam), P(ham)

n_total = train_size
n_spam = sum(1 for cls, bow in train if cls == 'spam')
n_ham = sum(1 for cls, bow in train if cls == 'ham')

alpha = 0.00001

prior_spam = (n_spam + alpha) / (n_total + 2 * alpha)
prior_ham = (n_ham + alpha) / (n_total + 2 * alpha)

print(prior_spam, prior_ham)

0.13074680587240722 0.8692531941275927


In [39]:
# 단어의 스팸메일에서의 등장 비율 계산: P(w | spam), P(w | ham)

spam_words = Counter(word for cls, bow in train for word in bow if cls == 'spam')
ham_words = Counter(word for cls, bow in train for word in bow if cls == 'ham')

In [40]:
# 스팸 분류 함수
def predict(bow):
    spam_score = prior_spam
    ham_score = prior_ham
    for word in bow:
        spam_score *= (spam_words[word] + alpha) / (n_spam + alpha)
        ham_score *= (ham_words[word] + alpha) / (n_ham + alpha)

    if spam_score < ham_score:
        return "ham"
    else:
        return "spam"

In [41]:
# 테스트
tp, tn, fp, fn = 0,0,0,0
for ans, bow in test:
    pred = predict(bow)
    if pred == 'spam':
        if ans == 'spam':
            tp += 1
        else: 
            fp += 1
    else:
        if ans == 'ham':
            tn += 1
        else:
            fn += 1

print(f"tp: {tp}, tn: {tn}, fp: {fp}, fn: {fn}")

tp: 155, tn: 924, fp: 27, fn: 9


In [42]:
# accuracy: 전체 결과를 맞춘 비율
acc = (tp + tn) / (tp + tn + fp + fn)
# precision: 스팸으로 예측한 결과 중 맞춘 비율
prec = tp / (tp + fp)
# recall: 실제 스팸 중 스팸으로 예측한 비율
recall = tp / (tp + fn)
# f1: precision 과 recall의 조화평균
f1 = 2 * prec * recall / (prec + recall)

print(f"acc: {acc:.3f}, prec: {prec:.3f}, recall: {recall:.3f}, f1: {f1:.3f}")

acc: 0.968, prec: 0.852, recall: 0.945, f1: 0.896
