### Test notebook 

In [32]:
import numpy as np

""" Get the 20 news groups data """
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
import pickle

# Newsgroup categories the model will be trained on.
cats = ['alt.atheism', 'sci.space']

# Raw text of the training split, limited to `cats`. Headers, footers and
# quoted replies are removed; shuffling is seeded so the order is reproducible.
newsgroups = fetch_20newsgroups(
    subset="train",
    categories=cats,
    remove=("headers", "footers", "quotes"),
    shuffle=True,
    random_state=1,
)


In [None]:
""" Prepare input for sklearn (counts) """
# Upper bound on the vocabulary size kept by the vectorizer.
n_features = 1000
vectorizer = CountVectorizer(max_features=n_features, stop_words="english")

# Word counts per document matrix (input for sklearn)
W_counts = vectorizer.fit_transform(newsgroups.data)

# Keep track of vocabulary to visualize top words of each topic.
# `get_feature_names()` was deprecated in scikit-learn 1.0 and removed in 1.2;
# prefer `get_feature_names_out()` and fall back for older installations.
# The result is converted to a list so `vocabulary` keeps the same type
# regardless of which method is available.
if hasattr(vectorizer, "get_feature_names_out"):
    vocabulary = list(vectorizer.get_feature_names_out())
else:
    vocabulary = vectorizer.get_feature_names()


In [None]:
""" Prepare input for our model (one hot vectors) """
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.stem.porter import PorterStemmer
from gensim.corpora import Dictionary
import string

# Per-document preprocessing pipeline:
#   1) tokenize
#   2) drop stop words
#   3) lemmatize
lemmatizer = WordNetLemmatizer()
#stemmer = PorterStemmer()

# TODO: handle this hand-written list of extra tokens in a better way
# TODO: check what CountVectorizer does for the same cases
special = ["''", "'s", "``", "n't", "...", "--"]

# Everything that must be discarded after tokenization: English stop words,
# single punctuation characters and the extra tokens listed above.
stop = set(stopwords.words('english'))
stop |= set(string.punctuation)
stop |= set(special)

def prepare_document(doc):
    """Convert a raw document string into a list of lemmatized tokens.

    The text is lower-cased and split with ``word_tokenize``. Tokens that
    appear in the module-level ``stop`` set are skipped; every remaining
    token is passed through ``lemmatizer.lemmatize``.

    Args:
        doc: The document text.

    Returns:
        A list of lemmatized tokens, in their original order.
    """
    lemmas = []
    for token in word_tokenize(doc.lower()):
        if token in stop:
            continue
        lemmas.append(lemmatizer.lemmatize(token))
    return lemmas


# List of documents (each document is a list of word tokens)
texts = [prepare_document(text) for text in newsgroups.data]

# Create a gensim dictionary for our corpus
dic = Dictionary(texts)

# Keep at most the k most frequent words.
# NOTE(review): only `keep_n` is passed, so gensim's default `no_below` /
# `no_above` document-frequency filters also apply -- confirm this is intended.
n_features = 1000
dic.filter_extremes(keep_n=n_features)
vocab_size = len(dic.token2id) # Vocabulary size

# List of documents (each document is now a list of word indexes)
# Words that did not survive the dictionary filtering are dropped.
token2id = dic.token2id
texts_idx = [[token2id[word] for word in text if word in token2id]
             for text in texts]

# Convert each index to a one hot vector. The identity matrix is built once
# and reused; indexing it with a list of ids returns a fresh
# (n_tokens, vocab_size) array per document.
# Documents left with no tokens are skipped, so W can be shorter than
# `texts` and its positions do not necessarily line up with `newsgroups.data`.
one_hot_rows = np.eye(vocab_size)
W = [one_hot_rows[text] for text in texts_idx if len(text) > 0]

# Keep track of id to word mapping to visualize top words of topics
id2word = {idx: word for word, idx in token2id.items()}

# Use a context manager so the file is flushed and closed deterministically
# (the previous bare open() call leaked the handle).
with open("W.p", "wb") as out_file:
    pickle.dump(W, out_file)

1000
