Commit 25590af: 51 changed files with 6,504 additions and 0 deletions.

@@ -0,0 +1,76 @@
# Multi Step Reasoning for Open Domain Question Answering

To-do
- [ ] Integrate with code for SGTree

![gif](image/multi-step-reasoner.png)

Code for the paper [Multi-step Retriever-Reader Interaction for Scalable Open-domain Question Answering](https://openreview.net/forum?id=HkfPSh05K7)

*Acknowledgement*: This codebase started from the awesome [Dr.QA repository](https://github.com/facebookresearch/DrQA) created and maintained by [Adam Fisch](https://people.csail.mit.edu/fisch/). Thanks Adam!

## Setup
The requirements are in the [requirements file](requirements.txt). In my environment, I also needed to set PYTHONPATH (as in [setup.sh](setup.sh)).
```
pip install -r requirements.txt
source setup.sh
```

## Data
We are making the pre-processed data and paragraph vectors available so that it is easier to get started. They can be downloaded from [here](http://iesl.cs.umass.edu/downloads/multi-step-reasoning-iclr19/data.tar.gz) (41GB compressed, 56GB decompressed). If you need the pretrained paragraph encoder used to generate the vectors, feel free to get in touch with me.
After untarring, you will find a directory corresponding to each dataset. Each directory further contains:
```
data/               -- Processed data (*.pkl files)
paragraph_vectors/  -- Saved paragraph vectors of context for each dataset, used for nearest-neighbor search
vocab/              -- int2str mapping
embeddings/         -- Saved lookup table for faster initialization. The embeddings are essentially saved fast-text embeddings.
```

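The `vocab/` mapping can be loaded with plain `json`, as sketched below. The file names (`ind2tok.json`, `tok2ind.json`) are an assumption based on what `Dictionary.save()` in this commit writes; adjust them to whatever you actually find after untarring.
```
import json
import os

# Hypothetical paths: adjust the dataset name and file names to the untarred layout.
vocab_dir = "<path-to-downloaded-data>/quasart/vocab"

with open(os.path.join(vocab_dir, "ind2tok.json")) as f:
    ind2tok = json.load(f)   # e.g. {"0": "<NULL>", "1": "<UNK>", ...}
with open(os.path.join(vocab_dir, "tok2ind.json")) as f:
    tok2ind = json.load(f)

print(len(tok2ind), "tokens in the vocabulary")
```
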
## Paragraph encoder
If you want to train new paragraph embeddings instead of using the ones we used, please refer to this [readme](paragraph_encoder/README.md).

## Training
```
python scripts/reader/train.py --data_dir <path-to-downloaded-data> --model_dir <path-to-downloaded-model> --dataset_name searchqa|triviaqa|quasart --saved_para_vectors_dir <path-to-downloaded-data>/dataset_name/paragraph_vectors/web-open
```
Some important command-line args:
```
dataset_name               -- searchqa|triviaqa|quasart
data_dir                   -- path to the dataset that you downloaded
model_dir                  -- path where the model will be checkpointed
saved_para_vectors_dir     -- path to the cached paragraph and query representations on disk; it is included in the data you downloaded
multi_step_reasoning_steps -- number of steps of interaction between the retriever and the reader
num_positive_paras         -- (relevant during training) number of "positive" (w.r.t. distant supervision) paragraphs fed to the reader model
num_paras_test             -- (relevant at inference time) number of paragraphs sent to the reader by the retriever
freeze_reader              -- when set to 1, the reader parameters are fixed and only the parameters of the GRU (multi-step-reasoner) are trained
fine_tune_RL               -- fine-tune the GRU (multi-step-reasoner) with reward (F1) from the fixed reader
```
Training details:
1. During training, we first train the reader model by setting ```multi_step_reasoning_steps = 1```
2. After the reader has been trained, we fix the reader and pretrain just the ```multi-step-reasoner``` (```freeze_reader = 1```)
3. Next, we fine-tune the reasoner with reinforcement learning (```freeze_reader = 1, fine_tune_RL = 1```)

In our experiments on searchqa and quasart, we found that step 2 (pretraining the GRU) was not necessary and the reasoner was able to learn directly via RL. However, pretraining never hurt performance either.

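For concreteness, here is a minimal sketch of the three-stage schedule, assuming the args listed above map directly to `scripts/reader/train.py` flags (e.g. `--multi_step_reasoning_steps`, `--freeze_reader`, `--fine_tune_RL`); check `python scripts/reader/train.py --help` for the exact names. How stages 2 and 3 pick up the stage-1 checkpoint from `model_dir` is not shown here.
```
import subprocess

# Assumed paths and dataset; flag names are taken from the list above.
DATA_DIR = "/path/to/downloaded/data"
MODEL_DIR = "/path/to/model/checkpoints"
DATASET = "quasart"  # or "searchqa" / "triviaqa"

def train(extra_args):
    """Launch one training stage of scripts/reader/train.py with the shared paths."""
    cmd = [
        "python", "scripts/reader/train.py",
        "--data_dir", DATA_DIR,
        "--model_dir", MODEL_DIR,
        "--dataset_name", DATASET,
        "--saved_para_vectors_dir",
        "{}/{}/paragraph_vectors/web-open".format(DATA_DIR, DATASET),
    ] + extra_args
    subprocess.run(cmd, check=True)

# Stage 1: train the reader alone (a single retriever-reader interaction step).
train(["--multi_step_reasoning_steps", "1"])

# Stage 2 (optional for searchqa/quasart): freeze the reader, pretrain the GRU reasoner.
train(["--freeze_reader", "1"])

# Stage 3: keep the reader frozen and fine-tune the reasoner with RL (F1 reward).
train(["--freeze_reader", "1", "--fine_tune_RL", "1"])
```
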
## Pretrained models

We are also providing pretrained models for download and scripts to run them directly. Download the pretrained models from [here](http://iesl.cs.umass.edu/downloads/multi-step-reasoning-iclr19/models.tar.gz).
```
Usage: /bin/bash run_pretrained_models.sh dataset_name data_dir model_dir out_dir
dataset_name -- searchqa|triviaqa|quasart
data_dir     -- path to the dataset that you downloaded
model_dir    -- path to the pretrained model that you downloaded
out_dir      -- directory for logging
```

## Citation
```
@inproceedings{
das2018multistep,
title={Multi-step Retriever-Reader Interaction for Scalable Open-domain Question Answering},
author={Rajarshi Das and Shehzaad Dhuliawala and Manzil Zaheer and Andrew McCallum},
booktitle={ICLR},
year={2019},
}
```
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,133 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Model architecture/optimization options for DrQA document reader."""

import argparse
import logging

logger = logging.getLogger(__name__)

# Index of arguments concerning the core model architecture
MODEL_ARCHITECTURE = {
    'model_type', 'embedding_dim', 'hidden_size', 'doc_layers',
    'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge',
    'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf'
}

# Index of arguments concerning the model optimizer/training
MODEL_OPTIMIZER = {
    'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay',
    'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb',
    'max_len', 'grad_clipping', 'tune_partial'
}


def str2bool(v):
    return v.lower() in ('yes', 'true', 't', '1', 'y')


def add_model_args(parser):
    parser.register('type', 'bool', str2bool)

    # Model architecture
    model = parser.add_argument_group('DrQA Reader Model Architecture')
    model.add_argument('--model-type', type=str, default='rnn',
                       help='Model architecture type')
    model.add_argument('--embedding-dim', type=int, default=300,
                       help='Embedding size if embedding_file is not given')
    model.add_argument('--hidden-size', type=int, default=128,
                       help='Hidden size of RNN units')
    model.add_argument('--doc-layers', type=int, default=3,
                       help='Number of encoding layers for document')
    model.add_argument('--question-layers', type=int, default=3,
                       help='Number of encoding layers for question')
    model.add_argument('--rnn-type', type=str, default='lstm',
                       help='RNN type: LSTM, GRU, or RNN')
    model.add_argument('--top-spans', type=int, default=10,
                       help='Aggregate scores over spans')

    # Model specific details
    detail = parser.add_argument_group('DrQA Reader Model Details')
    detail.add_argument('--concat-rnn-layers', type='bool', default=True,
                        help='Combine hidden states from each encoding layer')
    detail.add_argument('--question-merge', type=str, default='self_attn',
                        help='The way of computing the question representation')
    detail.add_argument('--use-qemb', type='bool', default=True,
                        help='Whether to use weighted question embeddings')
    detail.add_argument('--use-in-question', type='bool', default=True,
                        help='Whether to use in_question_* features')
    detail.add_argument('--use-pos', type='bool', default=True,
                        help='Whether to use pos features')
    detail.add_argument('--use-ner', type='bool', default=True,
                        help='Whether to use ner features')
    detail.add_argument('--use-lemma', type='bool', default=True,
                        help='Whether to use lemma features')
    detail.add_argument('--use-tf', type='bool', default=True,
                        help='Whether to use term frequency features')

    # Optimization details
    optim = parser.add_argument_group('DrQA Reader Optimization')
    optim.add_argument('--dropout-emb', type=float, default=0.4,
                       help='Dropout rate for word embeddings')
    optim.add_argument('--dropout-rnn', type=float, default=0.4,
                       help='Dropout rate for RNN states')
    optim.add_argument('--dropout-rnn-output', type='bool', default=True,
                       help='Whether to dropout the RNN output')
    optim.add_argument('--optimizer', type=str, default='adamax',
                       help='Optimizer: sgd or adamax')
    optim.add_argument('--learning-rate', type=float, default=0.1,
                       help='Learning rate for SGD only')
    optim.add_argument('--grad-clipping', type=float, default=10,
                       help='Gradient clipping')
    optim.add_argument('--weight-decay', type=float, default=0,
                       help='Weight decay factor')
    optim.add_argument('--momentum', type=float, default=0,
                       help='Momentum factor')
    optim.add_argument('--fix-embeddings', type='bool', default=True,
                       help='Keep word embeddings fixed (use pretrained)')
    optim.add_argument('--tune-partial', type=int, default=0,
                       help='Backprop through only the top N question words')
    optim.add_argument('--rnn-padding', type='bool', default=False,
                       help='Explicitly account for padding in RNN encoding')
    optim.add_argument('--max-len', type=int, default=15,
                       help='The max span allowed during decoding')


def get_model_args(args):
    """Filter args for model ones.
    From an args Namespace, return a new Namespace with *only* the args specific
    to the model architecture or optimization (i.e. the ones defined here).
    """
    global MODEL_ARCHITECTURE, MODEL_OPTIMIZER
    required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER
    arg_values = {k: v for k, v in vars(args).items() if k in required_args}
    return argparse.Namespace(**arg_values)


def override_model_args(old_args, new_args):
    """Set args to new parameters.
    Decide which model args to keep and which to override when resolving a set
    of saved args and new args. We keep the new optimization settings but leave
    the model architecture alone. (Note: the MODEL_OPTIMIZER filter is currently
    commented out below, so every differing arg is overridden.)
    """
    global MODEL_OPTIMIZER
    old_args, new_args = vars(old_args), vars(new_args)
    for k in new_args.keys():
        if k in old_args and old_args[k] != new_args[k]:
            # if k in MODEL_OPTIMIZER:
            logger.info('Overriding saved %s: %s --> %s' %
                        (k, old_args[k], new_args[k]))
            old_args[k] = new_args[k]
            # else:
            #     logger.info('Keeping saved %s: %s' % (k, old_args[k]))
        elif k not in old_args:
            logger.info("Adding new argument {}".format(k))
            old_args[k] = new_args[k]
    return argparse.Namespace(**old_args)
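
To make the intended flow concrete, here is a small usage sketch (not part of this file) that exercises the three helpers above; it assumes they are in scope, for example imported from this module.
```
import argparse

# Build a parser that understands the model/optimizer flags defined above.
parser = argparse.ArgumentParser('DrQA reader example')
add_model_args(parser)
args = parser.parse_args(['--hidden-size', '256', '--optimizer', 'sgd'])

# Keep only the architecture/optimizer fields, e.g. for saving with a checkpoint.
model_args = get_model_args(args)
print(model_args.hidden_size)   # 256

# When resuming, merge newly requested settings into the saved args;
# differing values are logged and overridden.
new_args = argparse.Namespace(**vars(model_args))
new_args.learning_rate = 0.05
resumed = override_model_args(model_args, new_args)
print(resumed.learning_rate)    # 0.05
```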
@@ -0,0 +1,152 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Data processing/loading helpers."""

import numpy as np
import logging
import unicodedata

from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from .vector import vectorize
import json
import os

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Dictionary class for tokens.
# ------------------------------------------------------------------------------


class Dictionary(object):
    NULL = '<NULL>'
    UNK = '<UNK>'
    START = 2

    @staticmethod
    def normalize(token):
        return unicodedata.normalize('NFD', token)

    def __init__(self, args):
        self.args = args
        self.tok2ind = {self.NULL: 0, self.UNK: 1}
        self.ind2tok = {0: self.NULL, 1: self.UNK}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        return iter(self.tok2ind)

    def __contains__(self, key):
        if type(key) == int:
            return key in self.ind2tok
        elif type(key) == str:
            return self.normalize(key) in self.tok2ind

    def __getitem__(self, key):
        if type(key) == int:
            return self.ind2tok.get(key, self.UNK)
        if type(key) == str:
            return self.tok2ind.get(self.normalize(key),
                                    self.tok2ind.get(self.UNK))

    def __setitem__(self, key, item):
        if type(key) == int and type(item) == str:
            self.ind2tok[key] = item
        elif type(key) == str and type(item) == int:
            self.tok2ind[key] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        token = self.normalize(token)
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def save(self):
        fout = open(os.path.join(self.args.vocab_dir, "ind2tok.json"), "w")
        json.dump(self.ind2tok, fout)
        fout.close()
        fout = open(os.path.join(self.args.vocab_dir, "tok2ind.json"), "w")
        json.dump(self.tok2ind, fout)
        fout.close()
        logger.info("Dictionary saved at {}".format(self.args.vocab_dir))

    def tokens(self):
        """Get dictionary tokens.
        Return all the words indexed by this dictionary, except for special
        tokens.
        """
        tokens = [k for k in self.tok2ind.keys()
                  if k not in {'<NULL>', '<UNK>'}]
        return tokens


# ------------------------------------------------------------------------------
# PyTorch dataset class for SQuAD (and SQuAD-like) data.
# ------------------------------------------------------------------------------


class ReaderDataset(Dataset):

    def __init__(self, args, examples, word_dict, feature_dict, single_answer=False, train_time=False):
        self.args = args
        self.word_dict = word_dict
        self.feature_dict = feature_dict
        self.examples = examples
        # make a list of qids, so that we can iterate over them efficiently
        self.qids = list(examples.questions.keys())
        self.single_answer = single_answer
        self.train_time = train_time

    def __len__(self):
        return len(self.examples.questions)

    def __getitem__(self, index):
        question = self.examples.questions[self.qids[index]]
        paragraphs = [self.examples.paragraphs[pid] for pid in question.pids]

        return vectorize(self.args, question, paragraphs, self.word_dict, self.feature_dict, self.single_answer,
                         train_time=self.train_time)

    def lengths(self):
        return [(len(ex['document']), len(ex['question']))
                for ex in self.examples]


# ------------------------------------------------------------------------------
# PyTorch sampler returning batches of sorted lengths (by doc and question).
# ------------------------------------------------------------------------------


class SortedBatchSampler(Sampler):

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Sort example indices by decreasing document length, then question
        # length, with a random key to break ties; group them into batches of
        # batch_size, shuffle the batches, and yield the flattened indices.
        lengths = np.array(
            [(-l[0], -l[1], np.random.random()) for l in self.lengths],
            dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)]
        )
        indices = np.argsort(lengths, order=('l1', 'l2', 'rand'))
        batches = [indices[i:i + self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            np.random.shuffle(batches)
        return iter([i for batch in batches for i in batch])

    def __len__(self):
        return len(self.lengths)
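
As a rough illustration of how these pieces behave, here is a small sketch (not part of this file) exercising the `Dictionary` and `SortedBatchSampler` classes above; it assumes both are importable from this module. In the real training code, the lengths would come from `ReaderDataset.lengths()` and the sampler would be handed to a `torch.utils.data.DataLoader` through its `sampler` argument.
```
from argparse import Namespace

# Dictionary: indices 0 and 1 are reserved for <NULL> and <UNK>; the args
# object only needs a vocab_dir attribute, and only if save() is called.
word_dict = Dictionary(Namespace(vocab_dir='/tmp'))
for tok in ['Where', 'is', 'the', 'Eiffel', 'Tower', '?']:
    word_dict.add(tok)

print(len(word_dict))        # 8: six tokens plus <NULL> and <UNK>
print(word_dict['Eiffel'])   # 5: indices follow add() order after the specials
print(word_dict['zebra'])    # 1: unseen tokens fall back to the <UNK> index
print(word_dict[1])          # '<UNK>'

# SortedBatchSampler: synthetic (doc_len, question_len) pairs for six examples.
lengths = [(120, 10), (95, 8), (300, 12), (40, 5), (150, 9), (60, 7)]
sampler = SortedBatchSampler(lengths, batch_size=2, shuffle=False)
print([int(i) for i in sampler])   # [2, 4, 0, 1, 5, 3]: decreasing doc length

# With shuffle=True the batches [2, 4], [0, 1], [5, 3] stay intact but appear
# in random order, so each batch still groups similarly sized documents.
```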