# Text Classification
*Complete and hand in this completed worksheet (including its outputs and any supporting code outside of the worksheet) with your assignment submission. Please check the pdf file for more details.*

In this exercise you will:
    
- implement a of spam classifier with **Naive Bayes method** for real world email messages
- learn the **training and testing phase** for Naive Bayes classifier  
- get an idea of the **precision-recall** tradeoff

In [1]:
# some basic imports
import numpy as np
import matplotlib.pyplot as plt
import scipy.sparse
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [2]:
# ham_train contains the occurrences of each word in ham emails. 1-by-N vector
ham_train = np.loadtxt('ham_train.csv', delimiter=',')
# spam_train contains the occurrences of each word in spam emails. 1-by-N vector
spam_train = np.loadtxt('spam_train.csv', delimiter=',')
# N is the size of vocabulary.
N = ham_train.shape[0]
# There 9034 ham emails and 3372 spam emails in the training samples
num_ham_train = 9034
num_spam_train = 3372
# Do smoothing
x = np.vstack([ham_train, spam_train]) + 1

# ham_test contains the occurences of each word in each ham test email. P-by-N vector, with P is number of ham test emails.
i,j,ham_test = np.loadtxt('ham_test.txt').T
i = i.astype(np.int)
j = j.astype(np.int)
ham_test_tight = scipy.sparse.coo_matrix((ham_test, (i - 1, j - 1)))
ham_test = scipy.sparse.csr_matrix((ham_test_tight.shape[0], ham_train.shape[0]))
ham_test[:, 0:ham_test_tight.shape[1]] = ham_test_tight
# spam_test contains the occurences of each word in each spam test email. Q-by-N vector, with Q is number of spam test emails.
i,j,spam_test = np.loadtxt('spam_test.txt').T
i = i.astype(np.int)
j = j.astype(np.int)
spam_test_tight = scipy.sparse.csr_matrix((spam_test, (i - 1, j - 1)))
spam_test = scipy.sparse.csr_matrix((spam_test_tight.shape[0], spam_train.shape[0]))
spam_test[:, 0:spam_test_tight.shape[1]] = spam_test_tight


## Now let's implement a ham/spam email classifier. Please refer to the PDF file for details

In [14]:
from likelihood import likelihood
# TODO
# Implement a ham/spam email classifier, and calculate the accuracy of your classifier

# Hint: you can directly do matrix multiply between scipy.sparse.coo_matrix and numpy.array.
# Specifically, you can use sparse_matrix * np_array to do this. Note that when you use "*" operator
# between numpy array, this is typically an elementwise multiply.

# begin answer
l = likelihood(x)
# a
ratio = l[1] / l[0]
max10_idx = np.argsort(ratio)[-10:]

import linecache
for i in max10_idx:
    s = linecache.getline('all_word_map.txt', i+1).strip()
    print(s)
    print(ratio[i])


class SpamClassifier:
    def __init__(self):
        self.class_num = 2
        self.trained = False

    def train(self, x, sample_nums):
        self.likelihood = likelihood(x)
        self.prior = np.array(sample_nums) / np.sum(sample_nums)
        self.log_likelihood = np.log(self.likelihood)
        self.log_prior = np.log(self.prior)
        self.trained = True

    def __call__(self, x):
        if self.trained:
            log_posterior = x * self.log_likelihood.T + self.log_prior[np.newaxis, :]
            prediction = np.argmax(log_posterior, axis=1)
            return prediction
        else:
            print('Please train first!')

clf = SpamClassifier()
clf.train(x, [num_ham_train, num_spam_train])
ham_pred = clf(ham_test)
spam_pred = clf(spam_test)

ham_acc_num = ham_test.shape[0] - np.sum(ham_pred)
spam_acc_num = np.sum(spam_pred)
total_acc_num = ham_acc_num + spam_acc_num

ham_acc = ham_acc_num / ham_test.shape[0]
spam_acc = spam_acc_num / spam_test.shape[0]
total_acc = total_acc_num / (ham_test.shape[0] + spam_test.shape[0])

print('accuracy')
print(ham_acc_num, ham_test.shape[0], ham_acc)
print(spam_acc_num, spam_test.shape[0], spam_acc)
print(total_acc_num, ham_test.shape[0] + spam_test.shape[0], total_acc)

# end answer

ooking	9453
518.3682269968041
sex	56930
614.4894876319731
computron	13613
652.2514114529324
meds	37568
672.8488244461829
php	65398
768.9700850813518
voip	9494
837.6281283921868
cialis	45153
847.9268348888121
pills	38176
1101.9615951389017
viagra	75526
1249.5763882571969
nbsp	30033
1325.1002358991152
(3011,)
(1124,)
accuracy
9006 3011 2.991032879442046
1093 1124 0.9724199288256228
10099 4135 2.442321644498186
