In [30]:
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import matplotlib
# Switch matplotlib to the PGF backend (pyplot has already been imported above,
# so this changes the backend after the fact).
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",  # LaTeX engine used by the PGF backend
    'font.family': 'serif',  # use a serif family for figure text
    'text.usetex': True,  # typeset figure text with LaTeX
    'pgf.rcfonts': False,  # presumably leaves font selection to LaTeX rather than rcParams -- confirm in the matplotlib pgf docs
})
import pandas as pd
import pickle
import text_classifier as text_clf
import sys
sys.path.append("..")
import common

In [2]:
# Load the train/test splits of both datasets through the project-local
# `common` helper. Later cells index each split with 'texts' and 'labels'.
news_train, news_test = common.load_data('news')
spam_train, spam_test = common.load_data('spam')

In [3]:
def cross_validate(X, y, k, dims, num_epochs=5, ngrams=3, shuffle=False, random_state=None):
    """Estimate the mean k-fold validation accuracy for each embedding dimension.

    For every ``D`` in ``dims`` a ``text_clf.TextClassifier`` is built on each
    training fold and scored on the matching validation fold.

    Parameters
    ----------
    X : array-like
        Texts. Must support indexing with an integer index array
        (e.g. a NumPy array).
    y : array-like
        Labels aligned with ``X``; same indexing requirement.
    k : int
        Number of folds.
    dims : iterable of int
        Embedding dimensions to evaluate. Consumed exactly once, so a
        generator works as well as a list or ``range``.
    num_epochs : int, default 5
        Forwarded to ``TextClassifier``.
    ngrams : int, default 3
        Forwarded to ``TextClassifier``.
    shuffle : bool, default False
        Forwarded to ``KFold``. Enable it when the data is ordered (e.g. by
        label), otherwise contiguous folds are not representative.
    random_state : int or None, default None
        Forwarded to ``KFold``; only meaningful when ``shuffle`` is True.

    Returns
    -------
    dict
        Maps each dimension in ``dims`` to its accuracy averaged over folds.
    """
    kf = KFold(n_splits=k, shuffle=shuffle, random_state=random_state)
    # Materialise the folds once: the split does not depend on D, and reusing
    # it guarantees every dimension is scored on exactly the same partitions
    # (also when shuffling without a fixed seed).
    folds = list(kf.split(X))

    fold_accs = {}
    for D in dims:
        accs = fold_accs.setdefault(D, [])
        for train_idx, val_idx in folds:
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]

            clf = text_clf.TextClassifier(X_train, y_train, embed_dim=D, ngrams=ngrams, num_epochs=num_epochs)

            # Compare as arrays explicitly instead of relying on implicit
            # list-vs-array broadcasting.
            predictions = np.asarray([clf.predict(text) for text in X_val])
            accs.append(np.mean(predictions == y_val))

    return {D: np.mean(accs) for D, accs in fold_accs.items()}

In [4]:
# TAKES LITERALLY FOREVER TO RUN FOR NEWS

#for n in [3]:
#    accuracies = cross_validate(spam_train['texts'], spam_train['labels'], k=5, dims=range(1,50,4), ngrams=n)
#    pickle.dump(accuracies, open(f'spam_accs_{n}grams.p', 'wb'))

In [20]:
plt.style.use(['science'])


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the data has been read.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that were
    # produced by this project.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


plt.figure(figsize=(7.5,5))

# Each pickle holds a dict mapping embedding dimension -> accuracy
# (the format written by the cross-validation cell).
# Spam curves are dashed.
spam_acc_files = [f"CV/spam_accs_{n}grams.p" for n in [1,2,3]]
spam_acc_data = [_load_pickle(file) for file in spam_acc_files]
for data, label in zip(spam_acc_data, ["unigrams", "bigrams", "trigrams"]):
    plt.plot(list(data.keys()), list(data.values()), '--', label=f"Spam {label}")

# News curves are solid.
news_acc_files = [f"CV/news_accs_{n}grams.p" for n in [1,2,3]]
news_acc_data = [_load_pickle(file) for file in news_acc_files]
for data, label in zip(news_acc_data, ["unigrams", "bigrams", "trigrams"]):
    plt.plot(list(data.keys()), list(data.values()), '-', label=f"News {label}")

plt.legend()
plt.xlabel('Embedding dimension')
plt.ylabel('Validation accuracy')
plt.title('Mean validation accuracy for 5-fold cross-validation\n for uni-, bi- and trigrams of fastText')
plt.tight_layout()
plt.savefig("val_accs.pdf", bbox_inches='tight')

Train and save each classifier once to avoid re-training. To switch between training/saving and loading, swap which lines are commented out in the cells below.

`news_fasttext_classifier.p` should be around 580 MB.

In [42]:
# To retrain and overwrite the cached model, uncomment the three lines below.
#spam_clf = text_clf.TextClassifier(spam_train['texts'], spam_train['labels'], embed_dim=17, ngrams=1)
#with open('spam_fasttext_classifier.p', 'wb') as model_file:
#    pickle.dump(spam_clf, model_file)

# Load the previously trained classifier. The `with` block closes the file
# handle once the model is read.
# NOTE: unpickling can execute arbitrary code; only load model files created
# by this project.
with open('spam_fasttext_classifier.p', 'rb') as model_file:
    spam_clf = pickle.load(model_file)

In [48]:
# Predict a label for every held-out spam text and score against the truth.
spam_acc = accuracy_score(
    spam_test['labels'],
    list(map(spam_clf.predict, spam_test['texts'])),
)

print(f'Test accuracy for spam: {round(spam_acc,3)}')

Test accuracy for spam: 0.97


In [50]:
# One embedding per spam training text, in the same order as the texts.
spam_train_emb = list(map(spam_clf.get_text_embedding, spam_train['texts']))

In [44]:
# To retrain and overwrite the cached model, uncomment the three lines below.
# news_clf = text_clf.TextClassifier(news_train['texts'], news_train['labels'], embed_dim=49, ngrams=1)
# with open('news_fasttext_classifier.p', 'wb') as model_file:
#     pickle.dump(news_clf, model_file)

# Load the previously trained classifier. The `with` block closes the file
# handle once the model is read.
# NOTE: unpickling can execute arbitrary code; only load model files created
# by this project.
with open('news_fasttext_classifier.p', 'rb') as model_file:
    news_clf = pickle.load(model_file)

In [49]:
# Predict a label for every held-out news text and score against the truth.
news_acc = accuracy_score(
    news_test['labels'],
    list(map(news_clf.predict, news_test['texts'])),
)
print(f'Test accuracy for news: {round(news_acc,3)}')

Test accuracy for news: 0.905


In [53]:
# One embedding per news training text, in the same order as the texts.
news_train_emb = list(map(news_clf.get_text_embedding, news_train['texts']))