In [1]:
# imports
import string
import numpy as np

import sys
sys.path.append("..")
import utils

from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA

In [2]:
# stolen from notebook9
filename = "glove.6B.50d.txt"

embeddings = {}
with open(filename,'r', encoding='utf-8') as file:
    for line in file:
        elements = line.split();
        word = elements[0];
        vector = np.asarray(elements[1:],"float32")
        embeddings[word] = vector

In [37]:
# functions
def preprocess(text, punctuation=False):
    """
    Parameters:
    text (string): raw text from data

    Returns:
    string: string with no punctuation and
            only contains words in embeddings
    """
    if punctuation:
        text_ = ""
        for char in text:
            if char not in string.punctuation:
                text_ += char
        text = text_
    
    text_ = ""
    for word in text.split():
        if word in embeddings:
             text_ += word.lower() + " "
    return text_

def mean_emb(text):
    """
    Parameters:
    text (string): raw text from data
    
    Returns:
    numpy.ndarray(50,): mean embedding of text
    """
    pre = preprocess(text).split()
    sum_emb = np.zeros(50,)
    for word in pre:
        sum_emb += embeddings[word]
    mean_emb = sum_emb/50
    
    return mean_emb

def glove_accuracy(X_train, X_test, y_train, y_test):
    train_embeddings = np.array([mean_emb(text) for text in X_train])
    test_embeddings = np.array([mean_emb(text) for text in X_test])

    clf = GaussianNB().fit(train_embeddings, y_train)

    predictions = clf.predict(test_embeddings)
    accuracy = accuracy_score(y_test, predictions)

    return accuracy

In [38]:
## Spam data

In [39]:
spam_train, spam_test = utils.load_data('spam')
glove_accuracy(spam_train['texts'], spam_test['texts'], spam_train['labels'], spam_test['labels'])

0.9336322869955157

In [33]:
news_train, news_test = utils.load_data('news')
glove_accuracy(news_train['texts'], news_test['texts'], news_train['labels'], news_test['labels'])

0.8372368421052632