In [55]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns;
sns.set(style="ticks", color_codes=True)

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

import numpy as np
import torch
import pyro
import pyro.distributions as dist
from torch.nn.functional import softplus

# dataset

We will use the 20 Newsgroups dataset. Example usage of this data can be found [here](http://scikit-learn.org/stable/datasets/index.html#the-20-newsgroups-text-dataset).

In [2]:
# Six newsgroups in three related pairs: religion, computing/science, sport.
# NOTE: sklearn numbers the targets following docs_train.target_names
# (alphabetical), not the order of this list -- use target_names to decode ids.
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
    'rec.sport.baseball',
    'rec.sport.hockey',
]
docs_train, docs_test = (
    fetch_20newsgroups(subset=split, categories=categories)
    for split in ('train', 'test')
)
# Vocabulary size kept by the feature vectorizer.
num_features = 5000

## getting a better vocab

Extract the top words from each of the following subgroups. scikit-learn would not let me fetch the 20 Newsgroups data for a single topic (it only returned ~5 results), so each subgroup below combines two related newsgroups.

In [56]:
def fit_group_vectorizer(group_categories):
    """
    Fetch the training split for ``group_categories`` (headers and footers
    stripped) and fit a fresh English-stop-word TF-IDF vectorizer on it.

    Returns:
        (documents bunch, fitted TfidfVectorizer, document-term matrix)
    """
    group_docs = fetch_20newsgroups(subset='train', remove=('headers', 'footers'), categories=group_categories)
    group_vectz = TfidfVectorizer(stop_words='english')
    group_vect = group_vectz.fit_transform(group_docs.data)
    return group_docs, group_vectz, group_vect

docs_train_relig, relig_vectz, relig_vect = fit_group_vectorizer(['alt.atheism', 'talk.religion.misc'])
docs_train_sport, sport_vectz, sport_vect = fit_group_vectorizer(['rec.sport.baseball', 'rec.sport.hockey'])
# Each group gets its own vectorizer. The previous version refit sport_vectz on
# the sci documents, overwriting the sport vocabulary and leaving sci_vectz unfitted.
docs_train_sci, sci_vectz, sci_vect = fit_group_vectorizer(['comp.graphics', 'sci.space'])

In [57]:
# (term, column index) pairs from the religion vectorizer, highest index first.
# NOTE: vocabulary_ values are column indices, which follow the alphabetical
# order of the terms (see the output: 'zyklon' has the largest index), so this
# lists words in reverse alphabetical order; it does not rank them by importance.
relig_tuplist = sorted(
    relig_vectz.vocabulary_.items(),
    key=lambda term_and_col: term_and_col[1],
    reverse=True,
)
dict(relig_tuplist)

{'zyklon': 16314,
 'zxmkr08': 16313,
 'zus': 16312,
 'zurvanism': 16311,
 'zur': 16310,
 'zumder': 16309,
 'zues': 16308,
 'zuck': 16307,
 'zubin': 16306,
 'zoroastrians': 16305,
 'zoroastrianism': 16304,
 'zoroastrian': 16303,
 'zoroaster': 16302,
 'zoro': 16301,
 'zorastrian': 16300,
 'zoo': 16299,
 'zone': 16298,
 'zombie': 16297,
 'zlumber': 16296,
 'zip': 16295,
 'zionist': 16294,
 'zion': 16293,
 'zillions': 16292,
 'zeus02': 16291,
 'zeus': 16290,
 'zeroed': 16289,
 'zero': 16288,
 'zen': 16287,
 'zeitgeist': 16286,
 'zeit': 16285,
 'zechariah': 16284,
 'zech': 16283,
 'zebras': 16282,
 'zealous': 16281,
 'zealots': 16280,
 'zeal': 16279,
 'zdv': 16278,
 'zc': 16277,
 'zazen': 16276,
 'zarathushtra': 16275,
 'zappa': 16274,
 'zakat': 16273,
 'zakariyah': 16272,
 'zahlah': 16271,
 'zach': 16270,
 'za': 16269,
 'z1dan': 16268,
 'z1': 16267,
 'yusuf': 16266,
 'yup': 16265,
 'yukky': 16264,
 'yugoslavian': 16263,
 'yugoslavia': 16262,
 'yucky': 16261,
 'ysu': 16260,
 'yoyo': 16259,


In [42]:
# NOTE(review): `mega_vect` is not defined in any cell visible in this notebook
# (probably created in a since-deleted cell), so this cell raises NameError on a
# fresh kernel. Judging by its output, it was a vectorizer fitted on documents
# that still contained headers ('subject', 'organization', 'lines' are in its
# vocabulary) -- TODO confirm and restore the defining cell.
mega_vect.vocabulary_

{'jbh55289': 22933,
 'uxa': 40942,
 'cso': 13382,
 'uiuc': 40168,
 'edu': 15987,
 'josh': 23250,
 'hopkins': 20961,
 'subject': 37674,
 'griffin': 19739,
 'office': 28972,
 'exploration': 17221,
 'rip': 33959,
 'article': 7636,
 'news': 28196,
 'c51r3o': 10230,
 '9wk': 5385,
 'organization': 29304,
 'university': 40534,
 'illinois': 21539,
 'urbana': 40769,
 'lines': 24991,
 '23': 2697,
 'yamauchi': 42805,
 'ces': 11073,
 'cwru': 13565,
 'brian': 9711,
 'writes': 42554,
 'comments': 12119,
 'absorbtion': 5914,
 'space': 36646,
 'sciences': 34994,
 'reassignment': 32867,
 'chief': 11345,
 'engineer': 16445,
 'position': 31216,
 'just': 23423,
 'meaningless': 26344,
 'administrative': 6243,
 'shuffle': 35826,
 'does': 15292,
 'bode': 9326,
 'ill': 21533,
 'sei': 35273,
 'unfortunately': 40443,
 'things': 38962,
 'boding': 9331,
 'legitimate': 24669,
 'conjugation': 12529,
 'great': 19671,
 'ideas': 21417,
 'got': 19494,
 'money': 27249,
 've': 41130,
 'heard': 20372,
 'good': 19450,
 'ha

### feature extraction

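See Wikipedia for an explanation of [tf-idf](https://en.wikipedia.org/wiki/Tf%E2%80%93idf). Note that the vectorizer below is configured with `binary=True, use_idf=False, norm=None`, so it actually produces binary word-presence features (1 if a word occurs in the document, 0 otherwise) rather than tf-idf weights. These 0/1 features are what the naive Bayes model below expects.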

In [3]:
# Binary presence features: no idf weighting and no row normalisation,
# so every entry is 0 or 1.
vectorizer = TfidfVectorizer(
    stop_words='english',
    max_features=num_features,
    binary=True,
    use_idf=False,
    norm=None,
)
train_sparse = vectorizer.fit_transform(docs_train.data)
vectors_train = train_sparse.todense()
vectors_test = vectorizer.transform(docs_test.data).todense()
print('train: {}'.format(vectors_train.shape))
print('test: {}'.format(vectors_test.shape))

train: (3231, 5000)
test: (2149, 5000)


In [29]:
# See https://youtu.be/oAihxFkRHu8?t=29m34s for an explanation of N_c and N_c_j.
# N_c[c] is the number of training documents whose label is c (labels 0..5).
N_c = np.array([
    vectors_train[docs_train.target == class_lbl, :].shape[0]
    for class_lbl in range(6)
])

def pi_estimator():
    """
    Class priors: pi_hat[c] = N_c / N, the fraction of training documents
    that belong to class c (Nando calls the per-class count N_c).
    """
    total_docs = float(vectors_train.shape[0])
    return N_c / total_docs

def theta_estimator(X=None, y=None, n_classes=6, alpha=0.0):
    """
    Per-class Bernoulli word probabilities theta_hat[c, j].

    theta_hat[c, j] = (N_c_j + alpha) / (N_c + 2 * alpha), where N_c_j is the
    number of class-c documents whose feature j is set and N_c is the number of
    class-c documents. Each class gets one parameter per feature column
    (num_features of them for the default training matrix).

    Args:
        X: document-feature matrix with 0/1 entries, one row per document.
           Defaults to the global ``vectors_train``.
        y: integer class label for each row of X. Defaults to ``docs_train.target``.
        n_classes: number of classes; labels are 0 .. n_classes - 1.
        alpha: add-alpha (Laplace) smoothing. The default 0.0 gives the plain
           maximum-likelihood estimate, as before. With alpha > 0, every theta
           lies strictly inside (0, 1), so a word unseen in a class cannot
           zero out that class's likelihood.

    Returns:
        numpy array of shape (n_classes, n_features).
    """
    if X is None:
        X = vectors_train
    if y is None:
        y = docs_train.target
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)

    rows = []
    for class_lbl in range(n_classes):
        in_class = (y == class_lbl)
        # A single vectorised column sum per class, instead of one
        # fancy-indexing call per feature.
        word_counts = X[in_class].sum(axis=0)
        rows.append((word_counts + alpha) / (float(in_class.sum()) + 2.0 * alpha))
    return np.array(rows)

# Fit the model: class priors first, then per-class word-presence probabilities.
pi_hat, theta_hat = pi_estimator(), theta_estimator()

def prob_c_given(x_str):
    """
    Posterior class probabilities p(c | x) for one raw document under the
    naive Bayes model built from pi_hat and theta_hat.
    https://youtu.be/oAihxFkRHu8?t=29m44s

    p(c | x) is proportional to
        pi_hat[c] * prod_j theta_hat[c, j]**(x_j == 1) * (1 - theta_hat[c, j])**(x_j == 0)

    The product over thousands of features is computed as a sum of logs and
    normalised with the max-subtraction trick. A direct product of that many
    probabilities can underflow to 0 and silently trigger the fallback below.

    Args:
        x_str: document text, vectorised with the global ``vectorizer``.

    Returns:
        1-D array with one probability per class (len(pi_hat)), summing to 1.
        If every class has probability exactly zero (a theta_hat of 0 or 1
        contradicts the document in every class), returns a one-hot vector
        for class 0, matching the previous fallback.
    """
    # Vectorise the document once; the old version re-transformed it per class.
    x = np.asarray(vectorizer.transform([x_str]).todense()).ravel()[:num_features]
    theta = np.asarray(theta_hat, dtype=float)[:, :x.size]

    with np.errstate(divide='ignore'):
        log_prior = np.log(np.asarray(pi_hat, dtype=float))
        log_theta = np.log(theta)
        log_not_theta = np.log1p(-theta)

    # Same selection as theta**(x == 1) * (1 - theta)**(x == 0): each feature
    # contributes log(theta) if present and log(1 - theta) if absent.
    # np.where avoids 0 * -inf = nan when a theta is exactly 0 or 1.
    per_feature = np.where(x == 1, log_theta, np.where(x == 0, log_not_theta, 0.0))
    log_joint = log_prior + per_feature.sum(axis=1)

    best = log_joint.max()
    if not np.isfinite(best):
        # Every class has zero probability: keep the original fallback.
        fallback = np.zeros(len(log_joint))
        fallback[0] = 1.
        return fallback
    weights = np.exp(log_joint - best)
    return weights / weights.sum()

def show_test_idx(idx):
    """Print the predicted and true newsgroup for test document ``idx``, then its text."""
    doc_text = docs_test.data[idx]
    group_names = docs_test.target_names
    predicted_lbl = np.argmax(prob_c_given(doc_text))
    print('pred: {}'.format(group_names[predicted_lbl]))
    print('target: {}'.format(group_names[docs_test.target[idx]]))
    print('\n** doc **')
    print(doc_text)

def test_accuracy(report_every=100):
    """
    Classify every test document with ``prob_c_given`` and print the accuracy.

    The printed values are percentages (0-100). The old label "(0-1)" was
    wrong for values that had already been multiplied by 100.

    Args:
        report_every: print a running accuracy after every ``report_every``
            documents; 0 or None turns the running output off.
    """
    targets = docs_test.target
    testN = len(targets)
    if testN == 0:
        # Guard: avoid ZeroDivisionError on an empty test set.
        print('No test documents to evaluate.')
        return
    correct = 0
    for seen, (doc, target) in enumerate(zip(docs_test.data, targets), start=1):
        pred = np.argmax(prob_c_given(doc))
        if pred == target:
            correct += 1
        if report_every and seen % report_every == 0:
            print('Accuracy (%): {}'.format(correct / float(seen) * 100))
    print('Accuracy Final (%): {}'.format(correct / float(testN) * 100))
        
def test_accuracy_random(report_every=100, seed=None):
    """
    Baseline: accuracy of guessing a uniformly random class for each test document.

    With C classes, the expected accuracy is 100 / C percent. The printed
    values are percentages (0-100); the old "(0-1)" label was wrong.

    Args:
        report_every: print a running accuracy after every ``report_every``
            documents; 0 or None turns the running output off.
        seed: optional seed for a private RandomState, which makes the baseline
            reproducible. None (the default) draws from NumPy's global RNG, as before.
    """
    targets = docs_test.target
    testN = len(targets)
    if testN == 0:
        # Guard: avoid ZeroDivisionError on an empty test set.
        print('No test documents to evaluate.')
        return
    n_classes = len(docs_test.target_names)
    uniform = [1. / n_classes] * n_classes
    rng = np.random if seed is None else np.random.RandomState(seed)
    correct = 0
    for seen, target in enumerate(targets, start=1):
        pred = np.argmax(rng.multinomial(1, uniform, size=1)[0])
        if pred == target:
            correct += 1
        if report_every and seen % report_every == 0:
            print('Accuracy (%): {}'.format(correct / float(seen) * 100))
    print('Accuracy Final (%): {}'.format(correct / float(testN) * 100))
# test_accuracy_random()

Accuracy (0-1): 14.000000000000002
Accuracy (0-1): 17.0
Accuracy (0-1): 16.666666666666664
Accuracy (0-1): 15.5
Accuracy (0-1): 16.6
Accuracy (0-1): 15.833333333333332
Accuracy (0-1): 16.857142857142858
Accuracy (0-1): 16.75
Accuracy (0-1): 16.555555555555557
Accuracy (0-1): 16.400000000000002
Accuracy (0-1): 16.363636363636363
Accuracy (0-1): 16.166666666666664
Accuracy (0-1): 16.307692307692307
Accuracy (0-1): 16.785714285714285
Accuracy (0-1): 16.53333333333333
Accuracy (0-1): 16.5
Accuracy (0-1): 16.58823529411765
Accuracy (0-1): 16.38888888888889
Accuracy (0-1): 16.0
Accuracy (0-1): 15.7
Accuracy (0-1): 16.095238095238095
Accuracy Final (0-1): 15.960912052117262


# model

In [24]:
# Scratch check for the random baseline: a one-hot draw over six equally likely
# classes, reduced to the index of the drawn class.
np.random.multinomial(1, [1/6.]*6, size=1)[0].argmax()

1