In [5]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from pathlib import Path

In [6]:
# Text columns of the dataset: the question followed by its four answer options.
tcn = ['question'] + ['answer%s' % letter for letter in 'ABCD']

def read(path):
    """Load a tab-separated question file into a DataFrame indexed by ``id``.

    Parameters
    ----------
    path : str, path-like or file-like
        Passed straight to ``pd.read_csv`` and parsed with a tab separator.

    Returns
    -------
    pd.DataFrame
        Indexed by the ``id`` column cast to ``np.uint32``. When a
        ``correctAnswer`` column is present it is converted to a categorical
        and moved to the first position.
    """
    # ``sep`` must be a keyword: passing it positionally was deprecated and
    # pandas 2.0 rejects it with a TypeError.
    td = pd.read_csv(path, sep='\t')

    # NOTE(review): assumes ids are non-negative and fit in 32 bits;
    # ``astype`` would wrap silently otherwise.
    td['id'] = td['id'].astype(np.uint32)
    td = td.set_index('id')

    if 'correctAnswer' in td:
        # pop + insert moves the label column to the front as a categorical.
        td.insert(0, 'correctAnswer', td.pop('correctAnswer').astype('category'))

    return td

def clean(td, columns=None):
    """Tokenize and normalize the text columns of a question frame.

    Each cell in ``columns`` becomes a list of lower-cased NLTK word tokens,
    with English stop words and tokens whose last character is ASCII
    punctuation removed.

    Parameters
    ----------
    td : pd.DataFrame
        Frame whose text columns hold raw strings. It is NOT modified.
    columns : list of str, optional
        Columns to tokenize; defaults to the module-level ``tcn``.

    Returns
    -------
    pd.DataFrame
        A copy of ``td`` with ``columns`` replaced by token lists.
    """
    from nltk import word_tokenize as tokenize_words
    from nltk.corpus import stopwords
    from string import punctuation

    if columns is None:
        columns = tcn

    # Distinct names so the imported ``stopwords`` corpus module is not shadowed.
    stop_words = frozenset(stopwords.words('english'))
    punct = frozenset(punctuation)

    def tokenize(text):
        """Return the filtered, lower-cased tokens of ``text``."""
        words = []
        for token in tokenize_words(text):
            token = token.lower()
            # The last-character test drops pure punctuation tokens, but also
            # any token that merely ends in punctuation (e.g. an abbreviation
            # like "u.s." if the tokenizer emits one).
            if not token or token in stop_words or token[-1] in punct:
                continue
            words.append(token)
        return words

    # Work on a copy: mutating the caller's frame made re-running the cell
    # try to tokenize token lists.
    td = td.copy()
    for cn in columns:
        td[cn] = td[cn].map(tokenize)

    return td

# Load the training questions, then tokenize their question/answer text.
td = clean(read('../data/training_set.tsv'))

td

Unnamed: 0_level_0,correctAnswer,question,answerA,answerB,answerC,answerD
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100001,C,"[athletes, begin, exercise, heart, rates, resp...","[tissue, level]","[organ, level]","[system, level]","[cellular, level]"
100002,C,"[example, describes, learned, behavior, dog]","[smelling, air, odors]","[barking, disturbed]","[sitting, command]","[digging, soil]"
100003,D,"[two, nuclei, combined, one, nucleus, slight, ...",[conversion],[reaction],[fission],[fusion]
100004,B,"[distinction, epidemic, pandemic]","[symptoms, disease]","[geographical, area, affected]","[species, organisms, infected]","[season, disease, spreads]"
100005,B,"[way, orbit, comet, different, orbit, earth]","[orbit, earth, less, circular, orbit, comet]","[orbit, comet, elliptical, orbit, earth]","[orbital, period, earth, much, longer, orbital...","[orbital, period, comet, predictable, orbital,..."
100006,B,"[teacher, builds, model, hydrogen, atom, red, ...","[number, particles]","[relative, mass, particles]","[types, particles, present]","[charges, particles, present]"
100007,A,"[substance, student, apply, skin, gets, splash...",[water],[vinegar],[salt],[formaldehyde]
100008,A,"[main, source, energy, water, cycle]",[sun],"[fossil, fuels]",[clouds],[ocean]
100009,D,"[greatest, effect, aiding, movement, blood, hu...",[tension],[friction],[density],[gravity]
100010,C,"[time, non-volcanic, mountains, form, due, int...","[oceanic, plates, colliding, oceanic, plates]","[oceanic, plates, separating, oceanic, plates]","[continental, plates, colliding, continental, ...","[continental, plates, separating, continental,..."


In [68]:
# Dimensionality of every Doc2Vec vector; each text column contributes this
# many feature columns downstream.
vs = 300

def build_feature_extractor(texts, vector_size=None):
    """Train a Doc2Vec model on ``texts`` and return a vector-inference function.

    Parameters
    ----------
    texts : iterable of list of str
        Token lists; each is tagged with its position in the iterable. The
        iterable is consumed exactly once, so a generator is fine.
    vector_size : int, optional
        Embedding dimensionality; defaults to the module-level ``vs``.

    Returns
    -------
    callable
        ``extract_features(tokens)`` returning the model's inferred vector for
        ``tokens`` as a float32 numpy array.
    """
    import gensim
    from gensim.models.doc2vec import Doc2Vec, TaggedDocument
    from multiprocessing import cpu_count

    if vector_size is None:
        vector_size = vs

    # gensim 4 removed the ``size`` keyword in favour of ``vector_size``;
    # pick the spelling the installed major version understands.
    if int(gensim.__version__.split('.')[0]) >= 4:
        size_kwarg = {'vector_size': vector_size}
    else:
        size_kwarg = {'size': vector_size}

    model = Doc2Vec(
        [TaggedDocument(t, [i]) for i, t in enumerate(texts)],
        workers=cpu_count(),
        **size_kwarg
    )

    def extract_features(text):
        """Infer a float32 vector for one token list with the trained model."""
        return model.infer_vector(text).astype(np.float32)

    return extract_features


# Train Doc2Vec on every question and answer text (all columns pooled).
extract_features = build_feature_extractor(t for cn in tcn for t in td[cn].values)

# Inferring vectors is slow, so the feature frame is cached on disk.
if Path('features.pkl').exists():
    # NOTE: unpickling can execute code -- only load a cache this notebook wrote.
    vd = pd.read_pickle('features.pkl')
else:
    tfcn_for = {cn: ['%s_feature_%d' % (cn, i) for i in range(vs)] for cn in tcn}
    tfcn = [fcn for cn in tcn for fcn in tfcn_for[cn]]

    # Fill a preallocated float32 matrix (one row per question, column slices
    # in ``tcn`` order) and build the frame once, instead of assigning ranges
    # cell by cell through ``.loc``. ``DataFrame.to_dense`` is gone as well:
    # it was a no-op on a dense frame and pandas 1.0 removed it.
    feature_matrix = np.empty((len(td), len(tfcn)), dtype=np.float32)
    for row, texts in enumerate(tqdm(td[tcn].values)):
        feature_matrix[row] = np.concatenate([extract_features(t) for t in texts])

    vd = pd.DataFrame(feature_matrix, index=td.index, columns=tfcn)
    vd.insert(0, 'correctAnswer', td['correctAnswer'])

    vd.to_pickle('features.pkl')

# Detect a stale cache built from different data or a different ``vs``.
assert vd.index.equals(td.index), 'features.pkl does not match td; delete it and re-run'
assert vd.shape[1] == 1 + len(tcn) * vs, 'features.pkl was built with another vs; delete it and re-run'

vd

Unnamed: 0_level_0,correctAnswer,question_feature_0,question_feature_1,question_feature_2,question_feature_3,question_feature_4,question_feature_5,question_feature_6,question_feature_7,question_feature_8,...,answerD_feature_290,answerD_feature_291,answerD_feature_292,answerD_feature_293,answerD_feature_294,answerD_feature_295,answerD_feature_296,answerD_feature_297,answerD_feature_298,answerD_feature_299
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
100001,C,0.000855,0.000188,0.002267,-0.002812,0.000167,0.002964,0.002544,0.000042,0.001697,...,0.001240,0.002736,0.002357,-0.000352,0.000802,-0.002615,-0.000503,-0.000841,-1.489036e-03,0.003036
100002,C,-0.000790,0.005058,-0.000329,-0.004286,0.002028,0.002607,0.002770,0.001908,-0.000589,...,-0.003991,0.006600,0.004652,-0.004603,0.017015,-0.010040,-0.002924,-0.003232,4.546267e-03,0.015956
100003,D,-0.001595,0.007671,-0.000202,-0.007401,0.000197,0.004906,0.005155,0.002872,0.000937,...,-0.005757,0.008574,0.005002,-0.002991,0.017570,-0.009561,-0.001470,-0.000340,5.301210e-03,0.013743
100004,B,0.006605,-0.013449,-0.002904,0.011699,-0.000268,-0.008524,-0.010449,-0.008188,0.002134,...,0.003441,-0.007794,-0.004918,0.004435,-0.016565,0.009826,0.003235,0.002210,-3.842054e-03,-0.011347
100005,B,0.002086,-0.006844,-0.001040,0.010403,-0.003291,-0.005652,-0.006554,-0.002266,0.000379,...,0.005001,-0.005356,-0.002903,0.004893,-0.015533,0.007716,-0.000271,0.001492,-5.058169e-03,-0.010023
100006,B,0.002376,-0.004986,-0.000102,0.007887,-0.001626,-0.003677,-0.007257,-0.004577,-0.000086,...,-0.001714,0.004627,0.001465,-0.001016,0.006152,-0.003188,-0.001718,-0.000108,2.353626e-03,0.007946
100007,A,-0.004044,0.006757,0.002373,-0.008671,-0.000376,0.003654,0.006702,0.004319,-0.000841,...,-0.000571,0.000941,-0.000366,-0.000496,0.000970,-0.001177,-0.001556,-0.000286,-3.196509e-04,0.000814
100008,A,-0.003005,0.001872,0.001531,-0.003916,-0.002227,0.002368,0.002765,0.002370,-0.000917,...,0.007710,-0.009551,-0.004396,0.005450,-0.024425,0.014873,0.003751,0.002852,-2.422908e-03,-0.020400
100009,D,0.000818,-0.000452,-0.002007,0.002078,-0.001569,-0.002296,-0.001995,-0.000676,-0.000895,...,-0.005810,0.011904,0.006619,-0.004976,0.022988,-0.013763,-0.002601,-0.002100,6.134739e-03,0.020808
100010,C,-0.000049,-0.005234,0.001276,0.005626,-0.001981,-0.003025,-0.003685,-0.001604,0.000733,...,0.001899,-0.003838,-0.002534,0.002936,-0.005535,0.002353,0.001734,0.001189,-1.892741e-03,-0.004647


In [84]:
from sklearn.base import BaseEstimator, ClassifierMixin
from scipy.spatial import distance

class MyClassifier(BaseEstimator, ClassifierMixin):
    """Vector-analogy classifier for multiple-choice questions.

    Each row of ``X`` is a question vector followed by one vector per answer
    option, all of the same width (``X.shape[1] // (1 + n_answers)``).
    ``fit`` learns the mean question vector and the mean correct-answer
    vector. ``predict`` shifts each question by their difference and picks the
    option with the smallest cosine distance to the shifted vector.

    Parameters
    ----------
    n_answers : int, default 4
        Number of answer options per row; labels are ``'A'``, ``'B'``, ...
    """

    def __init__(self, n_answers=4):
        self.n_answers = n_answers

    def _split(self, X):
        """Return ``(questions, answers)`` shaped ``(n, d)`` and ``(n, n_answers, d)``.

        Casting to float64 avoids slow object-dtype arithmetic when ``X``
        comes from a mixed-dtype DataFrame.
        """
        X = np.asarray(X, dtype=np.float64)
        blocks = 1 + self.n_answers
        if X.ndim != 2 or X.shape[1] % blocks:
            raise ValueError('X must be 2-D with a column count divisible by %d' % blocks)
        width = X.shape[1] // blocks
        # Answer j occupies columns [(1 + j) * width, (2 + j) * width).
        return X[:, :width], X[:, width:].reshape(X.shape[0], self.n_answers, width)

    def fit(self, X, y):
        """Learn the mean question and mean correct-answer vectors.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` disagree in length, ``X`` has the wrong width,
            or a label is outside ``'A'`` .. the last option letter.
        """
        questions, answers = self._split(X)
        # Positional access: a pandas Series ``y`` would otherwise be label-indexed.
        y = np.asarray(y)
        if questions.shape[0] != y.shape[0]:
            raise ValueError('X.rows != y.rows')

        labels = np.array([ord(label) - ord('A') for label in y], dtype=np.intp)
        if labels.size and (labels.min() < 0 or labels.max() >= self.n_answers):
            raise ValueError('y must contain labels between A and %s'
                             % chr(ord('A') + self.n_answers - 1))

        self.question_essence = questions.mean(axis=0)
        self.correct_answer_essence = answers[np.arange(labels.size), labels].mean(axis=0)
        return self

    def predict(self, X):
        """Return an object array of predicted option letters, one per row."""
        questions, answers = self._split(X)
        quasi_correct = questions - self.question_essence + self.correct_answer_essence

        y = np.empty(questions.shape[0], dtype=object)
        for i in range(questions.shape[0]):
            j = np.argmin([distance.cosine(a, quasi_correct[i]) for a in answers[i]])
            y[i] = chr(ord('A') + j)
        return y

In [85]:
try:
    from sklearn.model_selection import cross_val_score
except ImportError:  # scikit-learn < 0.18 only ships the old module
    from sklearn.cross_validation import cross_val_score

# ``DataFrame.as_matrix`` was removed in pandas 1.0; ``.values`` is the
# equivalent (an object array here, because of the categorical label column).
v = vd.values

mc = MyClassifier()

# Column 0 is correctAnswer; the rest are features. Cast features to float so
# the classifier does not do object-dtype arithmetic.
cross_val_score(mc, v[:, 1:].astype(np.float64), v[:, 0], cv=10, scoring='accuracy').mean()

0.24840307921117213