In [75]:
import numpy as np

In [76]:
# Filenames
EXT =  "plt_large"
ID_MAP_FILE = "id_map_%s.pkl" % EXT
EMBED_FILE = "embed_%s.npz" % EXT

SYMBOLS_TO_REPLACE = '!"#$%&()*+,./:;<=>?@[\]^_`{|}~-' + "'"
SYMBOLS_TO_REMOVE = '"#$%&()*+,/:;<=>@[\]^_`{|}~-' + "'?!"

In [77]:
embed = []
with np.load(EMBED_FILE) as f:
     embed = f['embed']
embed.shape

(61273, 128)

In [78]:
import pickle
def load_obj(name ):
    with open(name, 'rb') as f:
        return pickle.load(f)
id_map = load_obj(ID_MAP_FILE)

print "Vocab Length:", len(id_map)

Vocab Length: 61273


In [79]:
def tokenize_text(words, id_map):
    ''' Convert cleaned text into list of tokens/word-ids.
    Args:
      id_map: Hashmap from word to id. id of 0 is always 'UNK' unkown token.
    Ret: list of ints (ids)
    '''
    words = words.split()
    tokens = []
    for i, word in enumerate(words):
        id = 0 # 0 is UNK id
        word = word.lower()
        if word in id_map:
            id = id_map[word]
        tokens.append(id)

    return np.array(tokens)

def clean_text(text_raw):
    text = text_raw
    if isinstance(text_raw, list):
        text = ' '.join(sentence for sentence in text_raw)
    text = filter(lambda x: x not in SYMBOLS_TO_REMOVE, text)
    text = text.split(". ")
    return text

def clean_sentence(sentence_raw):
    text = filter(lambda x: x not in (SYMBOLS_TO_REMOVE +'.'), sentence_raw)
    return text

In [80]:
import MovieQA
mqa = MovieQA.DataLoader()
story_raw, qa = mqa.get_story_qa_data('train', 'plot')

Initialized MovieQA data loader!


In [81]:
story = {}
story_lines = {}
for imdb_key in story_raw:
    tk_sent = []
    story_lines[imdb_key] = clean_text(story_raw[imdb_key])
    for line in story_lines[imdb_key]:
        tk_sent.append(tokenize_text(line, id_map))
    story[imdb_key] = tk_sent
    #story_lines[imdb_key] =  text
    
print story[imdb_key]
print story_lines[imdb_key]

[array([42,  0]), array([  935,     2,  1135,   145,  3069,     9,  9433,  1083,    89,
          16,  2626,     9,   892, 18022,  3268,    16,  1744,    91,
        3195,    33,   339,    10, 18023,    64,    10, 13356,  2316]), array([   10,  1881,  4780,    78,     0,  6377,   103,   161,    33,
        1419,   566,    42,    10, 11902,  2571,  1045,   106,    16,
         135, 14897,  3745,    42,    10,  3901]), array([  756,  1451, 18024,     9,   672,    23,  4024,    12,    13,
         257,  2184,    16,    13, 13281,    33,    46,   566,    98,
       17532,  3092,    78,     0]), array([  323,  2997,    10,  1646, 13281,    78,    10,  1882,   180,
         181,  2829,    73,    33,   310, 10750,    12,  6596,    16,
         280,    10, 10661,  1356,  1744,     9,    10,   219,    33,
        5049,   284,   248,   161,  5326,  2380,    78,  3268]), array([   10, 18022,    91,  3420,    42,    80,  4692, 18025,    16,
        1744,     9, 10067,    78,    23,  5731,  2184,  

In [82]:
def normalize(mat):
    if mat.ndim == 1:
        return mat/np.linalg.norm(mat)
    return mat/np.linalg.norm(mat, axis=1, keepdims=True)

for key in story:
    plot_vects = np.array([np.average(embed[sentence], axis=0) for sentence in story[key] if np.all(sentence!='')])
    plot_vects = normalize(plot_vects)
    story[key] = plot_vects



In [83]:
nCorrect = 0
nTried = len(qa)
for q in qa:
    # Process Question
    question = clean_sentence(q.question) 
    question = np.average(embed[tokenize_text(question, id_map)], axis=0)
    question = normalize(question)

    # Process answers
    answers = []
    for a in q.answers:
        if a == '':
            continue
        a_clean = clean_sentence(a)
        answers.append(np.average(embed[tokenize_text(a_clean, id_map)], axis=0))
    answers = normalize(np.array(answers))
    
    # Calculate similarity
    qscore = story[q.imdb_key].dot(question).reshape(-1, 1)
    ascore = story[q.imdb_key].dot(answers.T)
    score = ascore + qscore
    prediction = np.unravel_index(score.argmax(), score.shape)
    '''print score[prediction], np.amax(score)
    print prediction
    print q.correct_index'''
    if prediction[1] == q.correct_index:
        nCorrect+=1
        '''
        print "CORRRECTTTTTT!!!!"
    else:
        print "Incorrect D:"
    print "Question: ", q.question
    print "Correct: ", q.answers[q.correct_index]
    print 
    print "Ref: ", story_lines[q.imdb_key][prediction[0]]
    print "Model Chose: ", q.answers[prediction[1]]
    print '------------------' #'''

In [74]:
print(plot_vects[0].shape)
print story_lines[q.imdb_key]

(128,)
['Gary Hook a new recruit to the British Army takes leave of his much younger brother Darren', 'Hooks squad of British soldiers is sent to Belfast in 1971 in the early years of The Troubles', 'Under the leadership of the inexperienced Lieutenant Armitage his squad goes to a volatile area of Belfast where Catholic Nationalists and Protestant Loyalists live side by side', 'The unit provides support for the Royal Ulster Constabulary as it inspects homes for firearms shocking Hook with their rough treatment of women and children', 'The Catholic neighbourhood has been alerted to the activity and a crowd gathers to protest and provoke the British troops who though heavily armed can only respond by trying to hold the crowd back', 'One soldier leaves his gun on the ground in the confusion and a young boy runs off through the mob with it Hook and another pursue him', 'As the crowds protest escalates into rockthrowing the soldiers and police pull out leaving the two soldiers behind', 'Hoo

In [84]:
float(nCorrect)/nTried

0.4123679935012185