In [1]:
import os
import nltk
import math
import random
import collections

In [2]:
# Locate the project root relative to the notebook's working directory.
# `__file__` is not defined inside a notebook kernel, so the working directory
# is the anchor; the root is two directory levels above it.
# os.path.dirname is used rather than splitting on '/' so that (a) the platform's
# own path separator is honoured and (b) dropping the trailing component can
# never delete an earlier component that happens to share its name
# (list.remove(x) removes the FIRST match, not the last element).
parent = os.path.dirname(os.path.dirname(os.getcwd()))

# Corpus categories; each one names a directory under <parent>/corpus/processed.
categories = ['news', 'entertainment', 'sports', 'fun']

In [3]:
def fileids(category):
    """List the entry names found in one category's processed-corpus directory.

    The directory inspected is ``<parent>/corpus/processed/<category>``, where
    ``parent`` is the module-level project root. Names are returned exactly as
    ``os.listdir`` yields them: bare names (no directory prefix), in no
    guaranteed order.
    """
    return os.listdir(os.path.join(parent, 'corpus', 'processed', category))

def words(file):
    """Read a text file and return its tokens as one flat list.

    The file content is stripped of leading/trailing whitespace, split into
    lines on ``'\\n'``, and each line is split on single space characters.
    Empty strings (produced by consecutive spaces or blank lines) are dropped.
    Only the space character separates tokens; other whitespace such as tabs
    stays inside the token.

    Args:
        file: Path of the file to read (opened in text mode with the
            platform default encoding).

    Returns:
        list[str]: The tokens in reading order; ``[]`` for an empty file.
    """
    # Context manager guarantees the handle is closed even if read() fails.
    with open(file, 'r') as handle:
        text = handle.read().strip()
    return [token
            for line in text.split('\n')
            for token in line.split(' ')
            if token]

In [4]:
# Build the labelled corpus: one (tokens, category) pair per file, plus a flat
# list of every token (used afterwards for vocabulary statistics).
documents, total_words = [], []
for category in categories:
    # os.listdir gives no ordering guarantee, so sort the names to make the
    # pre-shuffle document order identical on every run and every filesystem.
    for fileid in sorted(fileids(category)):
        path = os.path.join(parent, 'corpus', 'processed',
                            category, fileid)
        w = words(path)
        documents.append((w, category))
        total_words.extend(w)
print('Word count:', len(total_words))

# Shuffle with a fixed seed so the train/test split derived from this order is
# reproducible under "Restart Kernel -> Run All". A dedicated Random instance
# is used so the global random state is left untouched.
random.Random(42).shuffle(documents)

Word count: 500369


In [5]:
# Vocabulary: every distinct token seen anywhere in the corpus.
all_words = set(total_words)
# Sorted so the feature order is the same on every run; iterating a set of
# strings directly yields an order that depends on per-process hash
# randomisation, which would make downstream output order vary between runs.
word_features = sorted(all_words)
print('Vocab size:', len(all_words))

Vocab size: 37446


In [6]:
def features(document, vocabulary=None):
    """Build a bag-of-words presence feature dict for one document.

    Args:
        document: Iterable of tokens making up the document.
        vocabulary: Iterable of words to test for. Defaults to the
            module-level ``word_features`` (looked up at call time), which
            preserves the original single-argument behaviour.

    Returns:
        dict[str, bool]: One ``'contains(<word>)'`` key per vocabulary word,
        in vocabulary order, mapped to whether that word occurs in
        ``document``.
    """
    # `is None` (not truthiness) so an explicitly empty vocabulary is honoured.
    if vocabulary is None:
        vocabulary = word_features
    # Set membership keeps each lookup O(1) across a large vocabulary.
    present = set(document)
    return {'contains({})'.format(word): (word in present)
            for word in vocabulary}

In [7]:
# Turn every labelled document into a (feature dict, label) pair.
feature_sets = [(features(tokens), label) for tokens, label in documents]

# The first 70% of the pairs (rounded up) form the training set; the rest is
# held out for testing.
cutoff = math.ceil(0.7 * len(feature_sets))
train_set = feature_sets[:cutoff]
test_set = feature_sets[cutoff:]

In [8]:
# Collect the labels of each split so the class balance can be inspected.
training_counts = [label for _, label in train_set]
test_counts = [label for _, label in test_set]

# Print one "<label> <count>" line per class for each split, with a blank
# line separating the two reports.
for position, (title, labels) in enumerate([('Training Set', training_counts),
                                            ('Test Set', test_counts)]):
    if position:
        print()
    print(title)
    for label, count in collections.Counter(labels).items():
        print(label, count)

Training Set
entertainment 3
news 2
fun 4
sports 5

Test Set
entertainment 2
news 3
fun 1


In [9]:
# Fit a Naive Bayes classifier on the training pairs, then report how it does
# on the held-out pairs.
classifier = nltk.NaiveBayesClassifier.train(train_set)
accuracy = nltk.classify.accuracy(classifier, test_set)
print('Accuracy:', accuracy)

# Show the features the classifier found most discriminative.
classifier.show_most_informative_features()

Accuracy: 0.8333333333333334
Most Informative Features
         contains(derek) = False             fun : sports =      3.6 : 1.0
        contains(eagles) = False             fun : sports =      3.6 : 1.0
      contains(redskins) = False             fun : sports =      3.6 : 1.0
   contains(first-round) = False             fun : sports =      3.6 : 1.0
       contains(trading) = False             fun : sports =      3.6 : 1.0
        contains(harris) = False             fun : sports =      3.6 : 1.0
     contains(garoppolo) = False             fun : sports =      3.6 : 1.0
          contains(pats) = False             fun : sports =      3.6 : 1.0
        contains(desean) = False             fun : sports =      3.6 : 1.0
          contains(mary) = True              fun : sports =      3.6 : 1.0
