# Prediction Personality (Algorithms for Zemanta Summer School)

In [83]:
import os
import nltk
import collections
import matplotlib.pyplot as plt
import random
from pyfm import pylibfm
import numpy
import scipy
from sklearn.metrics import auc, roc_curve, confusion_matrix
import subprocess

## Preprocessing

In [6]:
# Single Porter stemmer instance, reused for every token during preprocessing.
stem = nltk.PorterStemmer()

In [7]:
# Path of the input file; despite the .csv suffix it is split on tabs when loaded.
filename = 'reddit.csv'

In [8]:
# Read the raw dump and split every line into its tab-separated fields.
with open(filename) as f:
    data = [line.split("\t") for line in f.readlines()]

# Field 2 is used as the label, field 3 as the text of the post.
labels = [row[2] for row in data]

# Only tokens whose POS tag starts with one of these prefixes are kept
# (presumably nouns, verbs, adjectives and adverbs -- confirm against the
# tagset used by nltk.pos_tag).
content_tags = {"NN", "VB", "JJ", "RB"}

texts = []
for row in data:
    tokens = nltk.tokenize.word_tokenize(row[3].strip())
    kept = [word for word, tag in nltk.pos_tag(tokens) if tag[:2] in content_tags]
    # Lower-case before stemming so differently-cased forms share one stem.
    texts.append([stem.stem(word.lower()) for word in kept])

In [9]:
# Collapse repeated stems inside each post; going through a set means the
# original token order is not preserved.
texts = [list(unique_stems) for unique_stems in map(set, texts)]

In [10]:
# Count in how many posts each stem occurs.  The counter is fed straight from a
# generator instead of abusing a nested list comprehension for its append()
# side effect (which also built a throwaway list of Nones).
all_words = collections.Counter(word for text in texts for word in text)

In [11]:
# Vocabulary: stems counted more than 10 times across the corpus.
keep_words = {stemmed for stemmed, occurrences in all_words.items() if occurrences > 10}

In [12]:
# Drop every stem that did not make it into the vocabulary.
texts = [[token for token in tokens if token in keep_words] for tokens in texts]

In [13]:
# Freeze the vocabulary into a list so that every word has a fixed position
# when the per-post feature vectors are built from it.
words_for_model = list(keep_words)

In [14]:
# One binary bag-of-words vector per post, with columns ordered like
# words_for_model.  Each post is turned into a set once so the membership test
# is O(1) instead of a list scan repeated for every vocabulary word.
data_for_modeling = [
    [1 if word in present else 0 for word in words_for_model]
    for present in map(set, texts)
]

In [15]:
# Keep only the third character of each raw label, upper-cased.  Slicing
# (rather than indexing) yields "" instead of raising for labels shorter than
# three characters.
# NOTE(review): later cells compare this value against "T" and "F" --
# presumably one axis of a personality type code; confirm against the data.
labels_for_modeling = [raw_label[2:3].upper() for raw_label in labels]

In [16]:
# Pair every label with its feature vector as (label, features) tuples.
all_data = list(zip(labels_for_modeling, data_for_modeling))

In [17]:
# The first 26388 rows train the models; everything after is held out for testing.
# NOTE(review): the cut follows file order -- no shuffling is visible before it.
train_size = 26388
traning_data, testing_data = all_data[:train_size], all_data[train_size:]

## Factorization Machines

In [25]:
# Factorization machine with 2 latent factors, trained for 5 iterations;
# verbose=True prints the per-epoch training loss.
fm = pylibfm.FM(num_factors=2, num_iter=5, verbose=True)

In [26]:
# Feature vectors of the training rows (second element of each pair).
traning_texts = [features for _, features in traning_data]

In [27]:
# Labels of the training rows (first element of each pair).
traning_labels = [row_label for row_label, _ in traning_data]

In [28]:
# Stack the per-post 0/1 lists into one dense array (one row per training post).
traning_texts = numpy.array(traning_texts)

In [30]:
# Convert the dense training matrix to float64 CSR form.
# numpy.float64 replaces the numpy.float alias (removed in NumPy 1.24) and the
# public scipy.sparse.csr_matrix replaces the deprecated scipy.sparse.csr path.
traning_texts = scipy.sparse.csr_matrix(traning_texts.astype(numpy.float64))

In [32]:
# Binary float64 training targets: 1.0 for label "T", 0.0 for anything else.
# numpy.float64 is used because the numpy.float alias was removed in NumPy 1.24.
traning_labels = numpy.array(
    [1 if label == "T" else 0 for label in traning_labels], dtype=numpy.float64
)

In [33]:
# Fit the factorization machine on the sparse training matrix and its 0/1 targets.
fm.fit(traning_texts,traning_labels)

Creating validation dataset of 0.01 of training for adaptive regularization
-- Epoch 1
Training log loss: 0.96024
-- Epoch 2
Training log loss: 4.76361
-- Epoch 3
Training log loss: 2.47010
-- Epoch 4
Training log loss: 0.63715
-- Epoch 5
Training log loss: 0.48159


In [35]:
# Split the held-out (label, features) pairs into two parallel lists.
test_labels = [row_label for row_label, _ in testing_data]
test_texts = [features for _, features in testing_data]

In [46]:
# Score the test posts one at a time: each dense 0/1 vector is wrapped in a
# single-row float64 CSR matrix before being passed to the model.
# numpy.float64 and scipy.sparse.csr_matrix replace the removed numpy.float
# alias and the deprecated scipy.sparse.csr namespace.
results = [
    fm.predict(scipy.sparse.csr_matrix(numpy.array(text).astype(numpy.float64)))
    for text in test_texts
]

In [50]:
# Binary float64 test targets, encoded like the training targets: 1.0 for "T",
# 0.0 otherwise.  numpy.float64 replaces the numpy.float alias removed in
# NumPy 1.24.
test_class = numpy.array(
    [1 if label == "T" else 0 for label in test_labels], dtype=numpy.float64
)

In [51]:
# ROC curve of the factorization-machine scores against the 0/1 test targets,
# with class 1 treated as the positive class.
fpr, tpr, _ = roc_curve(test_class, results, pos_label=1)
# Area under that ROC curve.
auc_value = auc(fpr, tpr)
auc_value

0.7113963809889294

In [72]:
# Hard 0/1 predictions: a score strictly above 0.5 counts as class 1.
results_class = [1 if score > 0.5 else 0 for score in results]

In [73]:
# NOTE(review): scikit-learn's signature is confusion_matrix(y_true, y_pred).
# Here the predictions are passed first, so rows index the predicted class and
# columns the actual class (the transpose of the usual layout).
confusion_matrix(results_class, test_class)

array([[  31,  221],
       [ 156, 6189]])

## Vowpal Wabbit

In [100]:
with open('vorpal-wabbit-train.csv', 'w') as f:
    for text_class, text in traning_data:
        if text_class == "F":
            start = "1"
        else:
            start = "-1"
        f.write(start + " | "  + ", ".join([str(element) for element in text]) + "\n")

In [101]:
with open('vorpal-wabbit-test.csv', 'w') as f:
    for text_class, text in testing_data:
        if text_class == "F":
            start = "1"
        else:
            start = "-1"
        f.write(start + " | "  + ", ".join([str(element) for element in text]) + "\n")

In [102]:
# Train a logistic-loss model with 4 passes over the training file and save it
# to model.vw.  Multiple passes require a cache (-c); "-k" forces vw to rebuild
# that cache instead of silently reusing one left over from an earlier version
# of the training file.
subprocess.call([
    "vw", "vorpal-wabbit-train.csv",
    "-c", "-k", "--passes", "4",
    "-f", "model.vw",
    "--loss_function", "logistic",
])

0

In [103]:
# Score the held-out file with the saved model in test-only mode (-t, no
# learning) and write the predictions, one per line, to preds.txt.
subprocess.call(["vw", "vorpal-wabbit-test.csv", "-t", "-i", "model.vw", "-p", "preds.txt"])

0

In [104]:
# Load the raw prediction lines (trailing newlines included) written by vw.
with open("preds.txt") as predictions_file:
    vw_results = list(predictions_file)

In [105]:
# Display the raw prediction strings for inspection.
vw_results

['21.554028\n',
 '21.643206\n',
 '21.702654\n',
 '21.747244\n',
 '21.710087\n',
 '21.613478\n',
 '21.591183\n',
 '21.397966\n',
 '21.702654\n',
 '21.665499\n',
 '21.687790\n',
 '21.390532\n',
 '21.672928\n',
 '21.665497\n',
 '21.568886\n',
 '21.724947\n',
 '21.739811\n',
 '21.732380\n',
 '21.591177\n',
 '21.457413\n',
 '21.643198\n',
 '21.695221\n',
 '21.695223\n',
 '20.379837\n',
 '21.732380\n',
 '21.583752\n',
 '21.650633\n',
 '21.650635\n',
 '21.650633\n',
 '21.687792\n',
 '21.643200\n',
 '21.516870\n',
 '21.412825\n',
 '20.387274\n',
 '21.665497\n',
 '21.695221\n',
 '21.672930\n',
 '21.739813\n',
 '21.724949\n',
 '21.561460\n',
 '21.643202\n',
 '21.606043\n',
 '21.658068\n',
 '21.316217\n',
 '21.702654\n',
 '21.739813\n',
 '21.732382\n',
 '21.650633\n',
 '21.732380\n',
 '21.643202\n',
 '21.687792\n',
 '21.583752\n',
 '21.643202\n',
 '21.724949\n',
 '21.672928\n',
 '21.724949\n',
 '21.680361\n',
 '21.420256\n',
 '21.427691\n',
 '21.739813\n',
 '21.680359\n',
 '21.598612\n',
 '21.732