In [9]:
import collections, math, re

#單純貝氏模型

def tokenize(message):
    """Return the set of distinct lower-case alphanumeric tokens in ``message``.

    A set (not a list) is returned on purpose: the model below only asks
    whether a word occurs in a message, never how many times.

    Args:
        message: raw message text (here, an e-mail subject line).

    Returns:
        set of str tokens matching ``[a-z0-9]+`` after lower-casing.
    """
    message = message.lower()
    all_words = re.findall(r"[a-z0-9]+", message)
    return set(all_words)  # drop duplicates

def count_words(training_set):
    """Count, for every word, how many spam and how many non-spam messages contain it.

    Args:
        training_set: iterable of ``(message, is_spam)`` pairs.

    Returns:
        defaultdict mapping word -> ``[spam_message_count, non_spam_message_count]``.
        Each message adds at most 1 per word, because ``tokenize`` returns a set.
    """
    word_counts = collections.defaultdict(lambda: [0, 0])
    for text, spam_flag in training_set:
        column = 0 if spam_flag else 1  # slot 0 = spam, slot 1 = non-spam
        for token in tokenize(text):
            word_counts[token][column] += 1
    return word_counts

def word_probabilities(counts, total_spams, total_non_spams, k=0.5):
    """Turn per-word message counts into smoothed conditional probabilities.

    Converts counts into triples ``(w, P(w|spam), P(w|~spam))``.

    Args:
        counts: mapping word -> ``[spam_count, non_spam_count]`` (as produced by
            ``count_words``).
        total_spams: number of spam messages in the training set.
        total_non_spams: number of non-spam messages in the training set.
        k: additive-smoothing pseudocount; with ``k > 0`` an unseen-in-one-class
            word does not get probability exactly 0.

    Returns:
        list of ``(word, P(word | spam), P(word | not spam))`` triples.
    """
    # float() avoids Python 2 integer division when k is an int (e.g. k=0).
    spam_denominator = float(total_spams + 2 * k)
    non_spam_denominator = float(total_non_spams + 2 * k)
    # .items() works on Python 2 and 3; .iteritems() exists only on Python 2.
    return [(w,
             (spam + k) / spam_denominator,
             (non_spam + k) / non_spam_denominator)
            for w, (spam, non_spam) in counts.items()]

def spam_probability(word_probs, message):
    """Estimate P(spam | message) with a Bernoulli naive Bayes model.

    Every vocabulary word contributes to the score: a word present in the
    message adds ``log P(w|class)``, an absent word adds ``log(1 - P(w|class))``.
    No class prior term is applied, i.e. equal spam / non-spam priors are
    implicitly assumed.

    Args:
        word_probs: iterable of ``(word, P(word|spam), P(word|not spam))``
            triples, e.g. the output of ``word_probabilities``.
        message: raw message text; it is tokenized with ``tokenize``.

    Returns:
        float in [0, 1]: the posterior probability that ``message`` is spam.
    """
    message_words = tokenize(message)
    log_prob_if_spam = log_prob_if_not_spam = 0.0

    for word, p_word_if_spam, p_word_if_not_spam in word_probs:
        if word in message_words:
            log_prob_if_spam += math.log(p_word_if_spam)
            log_prob_if_not_spam += math.log(p_word_if_not_spam)
        else:
            log_prob_if_spam += math.log(1.0 - p_word_if_spam)
            log_prob_if_not_spam += math.log(1.0 - p_word_if_not_spam)

    # P(spam) = e^a / (e^a + e^b) = 1 / (1 + e^(b - a)).
    # Exponentiating a and b separately underflows to 0.0 / 0.0 (ZeroDivisionError)
    # once both log-likelihoods fall below about -745; only their difference matters.
    log_odds_against = log_prob_if_not_spam - log_prob_if_spam
    if log_odds_against > 0:
        # Equivalent form whose exp() argument is negative, so it cannot overflow.
        ratio = math.exp(-log_odds_against)
        return ratio / (1.0 + ratio)
    return 1.0 / (1.0 + math.exp(log_odds_against))


#貝氏分類器 (Naive Bayes classifier)
class NaiveBayesClassifier:
    """Spam classifier wrapping count_words / word_probabilities / spam_probability."""

    def __init__(self, k=0.5):
        self.k = k  # smoothing pseudocount passed to word_probabilities
        self.word_probs = []  # (word, P(w|spam), P(w|not spam)) triples; empty until train()

    def train(self, training_set):
        """Fit per-word probabilities from ``(message, is_spam)`` pairs.

        ``training_set`` may be any iterable (list, generator, ...): it is
        materialized once here because it is traversed twice (class totals
        and word counts) and its length is needed.
        """
        training_set = list(training_set)

        num_spams = sum(1 for _, is_spam in training_set if is_spam)
        num_non_spams = len(training_set) - num_spams

        word_counts = count_words(training_set)
        self.word_probs = word_probabilities(word_counts,
                                             num_spams,
                                             num_non_spams,
                                             self.k)

    def classify(self, message):
        """Return the estimated probability (0..1) that ``message`` is spam."""
        return spam_probability(self.word_probs, message)

##


In [2]:
import glob, io, os, re

# Corpus root: files two levels down (<root>\<folder>\<file>); folders whose name
# contains "ham" are treated as non-spam.
# NOTE(review): machine-specific absolute path — consider making it configurable.
path = r"C:\Users\tan\Documents\Spam\*\*"

data = []  # list of (subject, is_spam) pairs

# sorted(): glob's order is OS-dependent, which would make the seeded
# train/test split below non-reproducible across machines.
for fn in sorted(glob.glob(path)):
    # Look only at the containing folder's name, so "ham" appearing elsewhere in
    # the path (e.g. a user name such as "graham") cannot mislabel every file.
    folder = os.path.basename(os.path.dirname(fn))
    is_spam = "ham" not in folder

    # Mail files may contain non-UTF-8 bytes; ISO-8859-1 decodes any byte, so
    # reading cannot raise UnicodeDecodeError. io.open accepts `encoding` on py2 and py3.
    with io.open(fn, "r", encoding="ISO-8859-1") as file:
        for line in file:
            if line.startswith("Subject:"):
                subject = re.sub(r"^Subject: ", "", line).strip()
                data.append((subject, is_spam))


In [15]:
import random
from collections import Counter

# split_data, reused from chapter 11
def split_data(data, prob):
    """Randomly partition ``data`` into two lists.

    Each row, in order, consumes one ``random.random()`` draw and goes to the
    first list when the draw is below ``prob``, otherwise to the second.

    Returns:
        tuple ``(first, second)`` of lists.
    """
    first, second = [], []
    for row in data:
        target = first if random.random() < prob else second
        target.append(row)
    return first, second


random.seed(0)  # fixed seed: the split is reproducible for a given order of `data`
train_data, test_data = split_data(data, 0.75)  # ~75% train / ~25% test

classifier = NaiveBayesClassifier()
classifier.train(train_data)

# (subject, actual is_spam, predicted spam probability) for every test message
classified = [(subject, is_spam, classifier.classify(subject))
              for subject, is_spam in test_data]

# Confusion counts keyed by (actual is_spam, predicted is_spam) at a 0.5 threshold.
# `prob` is used instead of `spam_probability` so the function name is not shadowed.
counts = Counter((is_spam, prob > 0.5)
                 for _, is_spam, prob in classified)

print(counts)

Counter({(False, False): 709, (True, True): 99, (True, False): 40, (False, True): 28})


In [18]:
# Sort test rows ascending by predicted spam probability (in place, as before).
classified.sort(key=lambda row: row[2])

# List comprehensions instead of filter(): on Python 3 filter() returns an
# iterator, which cannot be sliced.
spammiest_hams = [row for row in classified if not row[1]][-5:]  # hams scored most spam-like
hammiest_spams = [row for row in classified if row[1]][:5]       # spams scored least spam-like

print(spammiest_hams)
print("")
print(hammiest_spams)


def p_spam_given_word(word_prob):
    """Return P(spam | word) for a ``(word, P(w|spam), P(w|not spam))`` triple.

    Bayes' rule with equal spam / non-spam priors, so the priors cancel out.
    """
    _, p_if_spam, p_if_not_spam = word_prob
    denominator = p_if_spam + p_if_not_spam
    return p_if_spam / denominator

# Vocabulary ordered by P(spam | word), least spammy first.
words = sorted(classifier.word_probs, key=p_spam_given_word)

spammiest_words = words[-5:]
hammiest_words = words[:5]

# print("") emits a blank separator line on both Python 2 and 3
# (the old `print ###` was a bare print followed by a comment).
print("")
print(spammiest_words)
print("")
print(hammiest_words)

[('2000+ year old Greek computer reinterpreted', False, 0.9849593308072281), ('What to look for in your next smart phone (Tech Update)', False, 0.9880531102637597), ('Your NEW "Leg-Up" on Wall Street...', False, 0.9919460757537283), ('[ILUG-Social] Re: Important - reenactor insurance needed', False, 0.9995884053884736), ('[ILUG-Social] Re: Important - reenactor insurance needed', False, 0.9995884053884736)]

[('Re: girls', True, 0.001076252533943303), ('Introducing Chase Platinum for Students with a 0% Introductory APR', True, 0.0013774569702223898), ('.Message report from your contact page....//ytu855 rkq', True, 0.0016561226234365182), ('Testing a system, please delete', True, 0.0029503997104894856), ('Never pay for the goodz again (8SimUgQ)', True, 0.006476931884025096)]

[('adv', 0.026027397260273973, 0.00022893772893772894), ('year', 0.028767123287671233, 0.00022893772893772894), ('sale', 0.031506849315068496, 0.00022893772893772894), ('systemworks', 0.036986301369863014, 0.000228

In [19]:
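## 改善 / Possible improvements:
## 1. 設定出現詞頻上下限 — set lower/upper word-frequency cutoffs (drop very rare and very common words).
## 2. 加入處理詞形變化的 stemmer — add a stemmer to normalize word forms, e.g. the Porter stemmer.
## 3. scikit-learn 裡有 BernoulliNB 模型 — scikit-learn ships a ready-made Bernoulli naive Bayes classifier (sklearn.naive_bayes.BernoulliNB) for the same word-presence features used here.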