ready.
rajarshd committed Mar 12, 2019
0 parents commit 25590af
Showing 51 changed files with 6,504 additions and 0 deletions.
76 changes: 76 additions & 0 deletions README.md
@@ -0,0 +1,76 @@
# Multi Step Reasoning for Open Domain Question Answering

To-do
- [ ] Integrate with code for SGTree

![gif](image/multi-step-reasoner.png)
Code for the paper [Multi-step Retriever-Reader Interaction for Scalable Open-domain Question Answering](https://openreview.net/forum?id=HkfPSh05K7)

*Acknowledgement*: This codebase started from the awesome [Dr.QA repository](https://github.com/facebookresearch/DrQA) created and maintained by [Adam Fisch](https://people.csail.mit.edu/fisch/). Thanks Adam!

## Setup
The requirements are in the [requirements file](requirements.txt). In my env, I also needed to set PYTHONPATH (as in [setup.sh](setup.sh)):
```
pip install -r requirements.txt
source setup.sh
```

## Data
We are making the pre-processed data and paragraph vectors available so that it is easier to get started. They can be downloaded from [here](http://iesl.cs.umass.edu/downloads/multi-step-reasoning-iclr19/data.tar.gz) (41GB compressed, 56GB decompressed). If you need the pretrained paragraph encoder used to generate the vectors, feel free to get in touch with me.
After untarring, you will find a directory corresponding to each dataset. Each directory further contains:
```
data/ -- Processed data (*.pkl files)
paragraph_vectors/ -- Saved paragraph vectors of context for each dataset used for nearest-neighbor search
vocab/ -- int2str mapping
embeddings/ -- Saved lookup table for faster initialization. The embeddings are essentially saved fast-text embeddings.
```
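
As a quick sanity check after extraction, the layout above can be verified with a few lines of Python. This is a minimal sketch: the subdirectory names come from the listing above, while `missing_subdirs` and the example path are hypothetical names used only for illustration.

```python
from pathlib import Path

# Subdirectory names taken from the listing above.
EXPECTED_SUBDIRS = ("data", "paragraph_vectors", "vocab", "embeddings")


def missing_subdirs(dataset_dir):
    """Return the expected subdirectories that are absent from dataset_dir."""
    root = Path(dataset_dir)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    # Hypothetical location; replace with wherever the archive was extracted.
    missing = missing_subdirs("/path/to/data/searchqa")
    print("missing:", missing or "none")
```

An empty list means all four directories are present for that dataset.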

## Paragraph encoder
If you want to train new paragraph embeddings instead of using the ones we provide, please refer to this [readme](paragraph_encoder/README.md).


## Training
```
python scripts/reader/train.py --data_dir <path-to-downloaded-data> --model_dir <path-to-downloaded-model> --dataset_name searchqa|triviaqa|quasart --saved_para_vectors_dir <path-to-downloaded-data>/dataset_name/paragraph_vectors/web-open
```
Some important command line args:
```
dataset_name -- searchqa|triviaqa|quasart
data_dir -- path to dataset that you downloaded
model_dir -- path where model would be checkpointed
saved_para_vectors_dir -- path to cached paragraph and query representations on disk. It should be in the data you have downloaded
multi_step_reasoning_steps -- Number of steps of interaction between retriever and reader
num_positive_paras -- (Relevant during training) -- Number of "positive" (wrt distant supervision) paragraphs fed to the reader model during training.
num_paras_test -- (Relevant during inference time) -- Number of paragraphs to be sent to the reader by the retriever.
freeze_reader -- when set to 1, the reader parameters are fixed and only the parameters of the GRU (multi-step-reasoner) are trained.
fine_tune_RL -- fine-tune the GRU (multi-step-reasoner) with reward (F1) from the fixed reader
```
Training details:
1. During training, we first train the reader model by setting ```multi_step_reasoning_steps = 1```
2. After the reader has been trained, we fix the reader and pretrain only the ```multi-step-reasoner``` (```freeze_reader = 1```)
3. Next, we fine tune the reasoner with reinforcement learning (```freeze_reader = 1, fine_tune_RL = 1```)

In our experiments for searchqa and quasart, we found that step 2 (pretraining the GRU) was not important and the reasoner was able to learn directly via RL. However, pretraining never hurt performance either.
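
The three stages above can be sketched as a sequence of `train.py` invocations. This is only an illustration under assumptions: it assumes the options listed earlier are passed as `--name value` flags (as `--data_dir` is in the command above), and `reasoning_steps=3` is a placeholder value, not a recommended setting. The helper `stage_commands` is hypothetical and not part of the repository; it only builds command strings and does not run anything.

```python
import shlex


def stage_commands(data_dir, model_dir, dataset_name, reasoning_steps=3):
    """Build one command line per training stage (hypothetical helper)."""
    base = [
        "python", "scripts/reader/train.py",
        "--data_dir", data_dir,
        "--model_dir", model_dir,
        "--dataset_name", dataset_name,
        "--saved_para_vectors_dir",
        f"{data_dir}/{dataset_name}/paragraph_vectors/web-open",
    ]
    # Assumption: each option is given as "--<name> <value>".
    stages = [
        # 1. Train the reader alone with a single reasoning step.
        {"multi_step_reasoning_steps": 1},
        # 2. Freeze the reader and pretrain the multi-step reasoner.
        {"multi_step_reasoning_steps": reasoning_steps, "freeze_reader": 1},
        # 3. Fine-tune the reasoner with RL, reader still frozen.
        {"multi_step_reasoning_steps": reasoning_steps, "freeze_reader": 1,
         "fine_tune_RL": 1},
    ]
    commands = []
    for overrides in stages:
        cmd = list(base)
        for name, value in overrides.items():
            cmd += [f"--{name}", str(value)]
        commands.append(shlex.join(cmd))
    return commands


if __name__ == "__main__":
    for line in stage_commands("/path/to/data", "/path/to/models", "searchqa"):
        print(line)
```

Since step 2 was reported as optional for searchqa and quasart, the second command can be skipped for those datasets.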

## Pretrained models

We are also providing pretrained models for download and scripts to run them directly. Download the pretrained models from [here](http://iesl.cs.umass.edu/downloads/multi-step-reasoning-iclr19/models.tar.gz).
```
Usage: /bin/bash run_pretrained_models.sh dataset_name data_dir model_dir out_dir
dataset_name -- searchqa|triviaqa|quasart
data_dir -- path to dataset that you downloaded
model_dir -- path to pretrained model that you downloaded
out_dir -- directory for logging
```

## Citation
```
@inproceedings{
das2018multistep,
title={Multi-step Retriever-Reader Interaction for Scalable Open-domain Question Answering},
author={Rajarshi Das and Shehzaad Dhuliawala and Manzil Zaheer and Andrew McCallum},
booktitle={ICLR},
year={2019},
}
```

Binary file added image/multi-step-reasoner.pdf
Binary file not shown.
Binary file added image/multi-step-reasoner.png
6 changes: 6 additions & 0 deletions msr/reader/__init__.py
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
Binary file added msr/reader/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file added msr/reader/__pycache__/config.cpython-36.pyc
Binary file not shown.
Binary file added msr/reader/__pycache__/data.cpython-36.pyc
Binary file not shown.
Binary file added msr/reader/__pycache__/layers.cpython-36.pyc
Binary file not shown.
Binary file not shown.
Binary file added msr/reader/__pycache__/model.cpython-36.pyc
Binary file not shown.
Binary file added msr/reader/__pycache__/predictor.cpython-36.pyc
Binary file not shown.
Binary file added msr/reader/__pycache__/rnn_reader.cpython-36.pyc
Binary file not shown.
Binary file added msr/reader/__pycache__/utils.cpython-36.pyc
Binary file not shown.
Binary file added msr/reader/__pycache__/vector.cpython-36.pyc
Binary file not shown.
133 changes: 133 additions & 0 deletions msr/reader/config.py
@@ -0,0 +1,133 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Model architecture/optimization options for DrQA document reader."""

import argparse
import logging

logger = logging.getLogger(__name__)

# Index of arguments concerning the core model architecture
MODEL_ARCHITECTURE = {
    'model_type', 'embedding_dim', 'hidden_size', 'doc_layers',
    'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge',
    'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf'
}

# Index of arguments concerning the model optimizer/training
MODEL_OPTIMIZER = {
    'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay',
    'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb',
    'max_len', 'grad_clipping', 'tune_partial'
}


def str2bool(v):
    return v.lower() in ('yes', 'true', 't', '1', 'y')


def add_model_args(parser):
    parser.register('type', 'bool', str2bool)

    # Model architecture
    model = parser.add_argument_group('DrQA Reader Model Architecture')
    model.add_argument('--model-type', type=str, default='rnn',
                       help='Model architecture type')
    model.add_argument('--embedding-dim', type=int, default=300,
                       help='Embedding size if embedding_file is not given')
    model.add_argument('--hidden-size', type=int, default=128,
                       help='Hidden size of RNN units')
    model.add_argument('--doc-layers', type=int, default=3,
                       help='Number of encoding layers for document')
    model.add_argument('--question-layers', type=int, default=3,
                       help='Number of encoding layers for question')
    model.add_argument('--rnn-type', type=str, default='lstm',
                       help='RNN type: LSTM, GRU, or RNN')
    model.add_argument('--top-spans', type=int, default=10,
                       help='aggregate scores over spans')

    # Model specific details
    detail = parser.add_argument_group('DrQA Reader Model Details')
    detail.add_argument('--concat-rnn-layers', type='bool', default=True,
                        help='Combine hidden states from each encoding layer')
    detail.add_argument('--question-merge', type=str, default='self_attn',
                        help='The way of computing the question representation')
    detail.add_argument('--use-qemb', type='bool', default=True,
                        help='Whether to use weighted question embeddings')
    detail.add_argument('--use-in-question', type='bool', default=True,
                        help='Whether to use in_question_* features')
    detail.add_argument('--use-pos', type='bool', default=True,
                        help='Whether to use pos features')
    detail.add_argument('--use-ner', type='bool', default=True,
                        help='Whether to use ner features')
    detail.add_argument('--use-lemma', type='bool', default=True,
                        help='Whether to use lemma features')
    detail.add_argument('--use-tf', type='bool', default=True,
                        help='Whether to use term frequency features')

    # Optimization details
    optim = parser.add_argument_group('DrQA Reader Optimization')
    optim.add_argument('--dropout-emb', type=float, default=0.4,
                       help='Dropout rate for word embeddings')
    optim.add_argument('--dropout-rnn', type=float, default=0.4,
                       help='Dropout rate for RNN states')
    optim.add_argument('--dropout-rnn-output', type='bool', default=True,
                       help='Whether to dropout the RNN output')
    optim.add_argument('--optimizer', type=str, default='adamax',
                       help='Optimizer: sgd or adamax')
    optim.add_argument('--learning-rate', type=float, default=0.1,
                       help='Learning rate for SGD only')
    optim.add_argument('--grad-clipping', type=float, default=10,
                       help='Gradient clipping')
    optim.add_argument('--weight-decay', type=float, default=0,
                       help='Weight decay factor')
    optim.add_argument('--momentum', type=float, default=0,
                       help='Momentum factor')
    optim.add_argument('--fix-embeddings', type='bool', default=True,
                       help='Keep word embeddings fixed (use pretrained)')
    optim.add_argument('--tune-partial', type=int, default=0,
                       help='Backprop through only the top N question words')
    optim.add_argument('--rnn-padding', type='bool', default=False,
                       help='Explicitly account for padding in RNN encoding')
    optim.add_argument('--max-len', type=int, default=15,
                       help='The max span allowed during decoding')


def get_model_args(args):
    """Filter args for model ones.
    From an args Namespace, return a new Namespace with *only* the args specific
    to the model architecture or optimization. (i.e. the ones defined here.)
    """
    global MODEL_ARCHITECTURE, MODEL_OPTIMIZER
    required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER
    arg_values = {k: v for k, v in vars(args).items() if k in required_args}
    return argparse.Namespace(**arg_values)


def override_model_args(old_args, new_args):
    """Set args to new parameters.
    Decide which model args to keep and which to override when resolving a set
    of saved args and new args.
    We keep the new optimization, but leave the model architecture alone.
    Note: with the MODEL_OPTIMIZER check commented out below, every saved
    argument whose value differs is currently overridden.
    """
    global MODEL_OPTIMIZER
    old_args, new_args = vars(old_args), vars(new_args)
    for k in new_args.keys():
        if k in old_args and old_args[k] != new_args[k]:
            # if k in MODEL_OPTIMIZER:
            logger.info('Overriding saved %s: %s --> %s' %
                        (k, old_args[k], new_args[k]))
            old_args[k] = new_args[k]
            # else:
            #     logger.info('Keeping saved %s: %s' % (k, old_args[k]))
        elif k not in old_args:
            logger.info("Adding new argument {}".format(k))
            old_args[k] = new_args[k]
    return argparse.Namespace(**old_args)
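
The `type='bool'` options in this file depend on registering `str2bool` under the name `'bool'` in the parser's type registry. The standalone sketch below shows how that behaves. It is an illustration, not part of the commit: it copies `str2bool` from the file, uses only two of the options defined there, and relies on `ArgumentParser.register`, which argparse provides but does not prominently document.

```python
import argparse


def str2bool(v):
    # Same rule as in config.py: only these strings count as True.
    return v.lower() in ('yes', 'true', 't', '1', 'y')


parser = argparse.ArgumentParser()
# Make the string 'bool' usable as a value for `type=`.
parser.register('type', 'bool', str2bool)
parser.add_argument('--use-pos', type='bool', default=True)
parser.add_argument('--hidden-size', type=int, default=128)

args = parser.parse_args(['--use-pos', 'false', '--hidden-size', '256'])
print(args.use_pos, args.hidden_size)

# Any unrecognized string silently becomes False rather than raising.
print(str2bool('maybe'))
```

One consequence of this design is that a misspelled value such as `--use-pos ture` is parsed as `False` without any warning.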
152 changes: 152 additions & 0 deletions msr/reader/data.py
@@ -0,0 +1,152 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Data processing/loading helpers."""

import numpy as np
import logging
import unicodedata

from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from .vector import vectorize
import json
import os
logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Dictionary class for tokens.
# ------------------------------------------------------------------------------


class Dictionary(object):
    NULL = '<NULL>'
    UNK = '<UNK>'
    START = 2

    @staticmethod
    def normalize(token):
        return unicodedata.normalize('NFD', token)

    def __init__(self, args):
        self.args = args
        self.tok2ind = {self.NULL: 0, self.UNK: 1}
        self.ind2tok = {0: self.NULL, 1: self.UNK}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        return iter(self.tok2ind)

    def __contains__(self, key):
        if type(key) == int:
            return key in self.ind2tok
        elif type(key) == str:
            return self.normalize(key) in self.tok2ind

    def __getitem__(self, key):
        if type(key) == int:
            return self.ind2tok.get(key, self.UNK)
        if type(key) == str:
            return self.tok2ind.get(self.normalize(key),
                                    self.tok2ind.get(self.UNK))

    def __setitem__(self, key, item):
        if type(key) == int and type(item) == str:
            self.ind2tok[key] = item
        elif type(key) == str and type(item) == int:
            self.tok2ind[key] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        token = self.normalize(token)
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def save(self):
        # Write both mappings as JSON files under args.vocab_dir.
        with open(os.path.join(self.args.vocab_dir, "ind2tok.json"), "w") as fout:
            json.dump(self.ind2tok, fout)
        with open(os.path.join(self.args.vocab_dir, "tok2ind.json"), "w") as fout:
            json.dump(self.tok2ind, fout)
        logger.info("Dictionary saved at {}".format(self.args.vocab_dir))

    def tokens(self):
        """Get dictionary tokens.
        Return all the words indexed by this dictionary, except for special
        tokens.
        """
        tokens = [k for k in self.tok2ind.keys()
                  if k not in {'<NULL>', '<UNK>'}]
        return tokens


# ------------------------------------------------------------------------------
# PyTorch dataset class for SQuAD (and SQuAD-like) data.
# ------------------------------------------------------------------------------

class ReaderDataset(Dataset):

    def __init__(self, args, examples, word_dict, feature_dict, single_answer=False, train_time=False):
        self.args = args
        self.word_dict = word_dict
        self.feature_dict = feature_dict
        self.examples = examples
        # make a list of qids, so that we can iterate over them efficiently
        self.qids = list(examples.questions.keys())
        self.single_answer = single_answer
        self.train_time = train_time

    def __len__(self):
        return len(self.examples.questions)

    def __getitem__(self, index):
        question = self.examples.questions[self.qids[index]]
        paragraphs = [self.examples.paragraphs[pid] for pid in question.pids]

        return vectorize(self.args, question, paragraphs, self.word_dict, self.feature_dict, self.single_answer,
                         train_time=self.train_time)

    def lengths(self):
        # NOTE: this appears to be carried over unchanged from DrQA. It treats
        # self.examples as an iterable of dicts, unlike the questions/paragraphs
        # access used in the methods above.
        return [(len(ex['document']), len(ex['question']))
                for ex in self.examples]


# ------------------------------------------------------------------------------
# PyTorch sampler returning batches of sorted lengths (by doc and question).
# ------------------------------------------------------------------------------


class SortedBatchSampler(Sampler):

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Negate the lengths so that argsort puts the longest examples first;
        # the random field breaks ties between examples of equal length.
        # np.float64 replaces the np.float_ alias, which NumPy 2.0 removed.
        lengths = np.array(
            [(-l[0], -l[1], np.random.random()) for l in self.lengths],
            dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float64)]
        )
        indices = np.argsort(lengths, order=('l1', 'l2', 'rand'))
        batches = [indices[i:i + self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            np.random.shuffle(batches)
        return iter([i for batch in batches for i in batch])

    def __len__(self):
        return len(self.lengths)
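
The ordering that `SortedBatchSampler.__iter__` relies on can be seen in isolation with a small structured array. This is a sketch with made-up `(document length, question length)` pairs, and the names `keys`, `order`, and `batches` are illustrative only. The batch shuffling step is left out so the result is deterministic.

```python
import numpy as np

# Hypothetical (document length, question length) pairs.
lengths = [(5, 2), (9, 3), (5, 4), (9, 1)]
batch_size = 2

rng = np.random.default_rng(0)
# Negated lengths make an ascending sort return the longest examples first.
keys = np.array(
    [(-doc_len, -q_len, rng.random()) for doc_len, q_len in lengths],
    dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float64)],
)
# Sort by document length, then question length, then the random tie-breaker.
order = np.argsort(keys, order=('l1', 'l2', 'rand'))

# Consecutive slices of the sorted indices form batches of similar length.
batches = [order[i:i + batch_size].tolist()
           for i in range(0, len(order), batch_size)]
print(batches)  # [[1, 3], [2, 0]]
```

Grouping examples of similar length keeps padding within each batch small, while shuffling whole batches (as the sampler does when `shuffle=True`) still varies the order in which batches are seen.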