In [1]:
%load_ext autoreload
%autoreload 2

import re
import torch
import pickle
import pandas as pd
from util.nlp import QA, Paragraph, Tokenizer
from util.model import QaPredictModel

In [2]:
def _unpickle(path):
    """Return the object deserialized from the pickle file at ``path``."""
    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # with artifacts produced by this project.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


weights = _unpickle('model/weights.pickle')
vocab = _unpickle('model/vocab.pickle')

# Build the QA helper around a tokenizer that uses the loaded vocab, then
# load the training frame and the derived X / y tables from it.
tokenizer = Tokenizer(vocab=vocab)
qas = QA(tokenizer=tokenizer)
df = qas.load_train()
X, y = qas.get_Xy()

embedding = torch.Tensor(qas.tokenizer.vocab.embeddings)

# num_features is the length of the first element of the first row's
# 'paragraph_features' entry; the remaining options are fixed here.
first_paragraph_features = X['paragraph_features'].iloc[0]
model_options = {
    'num_features': len(first_paragraph_features[0]),
    'hidden_size': 64,
    'doc_layers': 3,
    'question_layers': 3,
    'dropout_emb': 0.3,
    'dropout_rnn': 0.3,
    'tune_partial': 0,
    'cuda': True
}
model = QaPredictModel(model_options, embedding, weights)

100%|██████████████████████████████████████████████████████████████████████████| 50362/50362 [01:08<00:00, 740.53doc/s]
100%|██████████████████████████████████████████████████████████████████████████| 50362/50362 [01:02<00:00, 810.82doc/s]


In [3]:
# Copy the paragraph text and answer columns from ``y`` onto ``df``.
# ``.values`` strips the index, so rows are matched by order, not by label.
for target_col, source_col in (('paragraph', 'paragraph_text'), ('answer', 'answer')):
    df[target_col] = y[source_col].values

# Ordinal position of every row, used later to select rows from X and y.
df['i'] = range(len(df))

def get_answer_starts(row):
    """Locate every occurrence of the row's answer inside its paragraph.

    The answer string is searched for case-insensitively in the text of a
    ``Paragraph`` built from ``row['paragraph']``. Each character-level match
    is handed to ``Paragraph.find_answer_span`` and the first element of the
    value it returns is collected (presumably the start token of the span --
    confirm against ``util.nlp``).

    Args:
        row: mapping-like row exposing ``'paragraph'`` and ``'answer'`` strings.

    Returns:
        list: one entry per non-overlapping match, in order of appearance;
        empty when the answer does not occur in the paragraph text.
    """
    paragraph = Paragraph(row['paragraph'], tokenizer=qas.tokenizer)
    # Search the original text with IGNORECASE rather than lower-casing both
    # sides: str.lower() may change a string's length for some Unicode
    # characters, which would shift the offsets relative to paragraph.text.
    pattern = re.compile(re.escape(row['answer']), re.IGNORECASE)
    # m.end() is the actual end of the matched text, so the (start, end) pair
    # always describes the same string that was matched.
    return [
        paragraph.find_answer_span(m.start(), m.end())[0]
        for m in pattern.finditer(paragraph.text)
    ]

# Candidate start positions of the answer for every training row.
answer_starts = df.apply(get_answer_starts, axis=1)

# Keep only rows whose answer text occurs more than once in the paragraph.
multi_answers = answer_starts[answer_starts.map(len) > 1]

# 'i' holds each row's ordinal position in df (and, by construction, in X and
# y, whose values were copied onto df by position). Select it in a single
# .loc call instead of chained indexing.
idx = df.loc[multi_answers.index, 'i'].tolist()

# idx contains positions, not labels, so index X and y with .iloc; .loc would
# only be equivalent when X and y carry a default 0..n-1 index.
predicts = list(model.predicts(X.iloc[idx], y.iloc[idx]))

100%|███████████████████████████████████████████████████████████████████████████████| 46/46 [00:08<00:00,  5.90batch/s]


In [4]:
# One row per multi-occurrence example: its candidate starts (keeping the
# index of ``multi_answers``) next to the model prediction, which is attached
# by position in the order the predictions were produced.
df2 = multi_answers.to_frame(name='starts')
df2['predict'] = predicts

def get_best_start(row):
    """Pick which candidate start the prediction scores highest.

    Args:
        row: mapping-like row with ``'starts'`` (iterable of indices into the
            prediction's ``scores_start``) and ``'predict'`` (object exposing
            ``scores_start``).

    Returns:
        int: position within ``row['starts']`` of the candidate with the
        largest start score; the earliest position wins on ties.

    Raises:
        ValueError: if ``row['starts']`` is empty.
    """
    start_scores = row['predict'].scores_start
    candidate_scores = [start_scores[start] for start in row['starts']]
    # max() keeps the first maximal element, so ties resolve to the earliest
    # candidate.
    return max(range(len(candidate_scores)), key=lambda pos: candidate_scores[pos])

# For every multi-occurrence example, the position (within its candidate
# list) of the occurrence the model scores highest.
best_positions = df2.apply(get_best_start, axis=1)
answer_pos = best_positions.to_frame(name='answer_pos')

# Persist only the rows where the best occurrence is not the first one.
not_first_occurrence = answer_pos['answer_pos'] > 0
answer_pos[not_first_occurrence].to_csv('data/answer_pos.csv')