Skip to content

Commit

Permalink
Demo support
Browse files Browse the repository at this point in the history
  • Loading branch information
seominjoon committed Oct 29, 2018
1 parent ecb3f93 commit 4f07a6b
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 22 deletions.
5 changes: 5 additions & 0 deletions squad/base/argument_parser.py
Expand Up @@ -31,6 +31,7 @@ def add_arguments(self):
self.add_argument('--question_emb_dir', type=str, default=None)
self.add_argument('--context_emb_dir', type=str, default=None)

# Training arguments
self.add_argument('--epochs', type=int, default=20)
self.add_argument('--train_steps', type=int, default=0)
self.add_argument('--eval_steps', type=int, default=1000)
Expand All @@ -41,6 +42,9 @@ def add_arguments(self):
self.add_argument('--nlist', type=int, default=1)
self.add_argument('--nprobe', type=int, default=1)

# Demo arguments
self.add_argument('--port', type=int, default=8080)

# Other arguments
self.add_argument('--draft', default=False, action='store_true')
self.add_argument('--cuda', default=False, action='store_true')
Expand All @@ -49,6 +53,7 @@ def add_arguments(self):
self.add_argument('--archive', default=False, action='store_true')
self.add_argument('--dump_period', type=int, default=20)
self.add_argument('--emb_type', type=str, default='dense', help='dense|sparse')
self.add_argument('--metadata', default=False, action='store_true')

def parse_args(self, **kwargs):
args = super().parse_args()
Expand Down
20 changes: 15 additions & 5 deletions squad/base/file_interface.py
Expand Up @@ -96,7 +96,7 @@ def question_emb(self, id_, emb, emb_type='dense'):
path = os.path.join(self._question_emb_dir, '%s.npz' % id_)
savez(path, emb)

def context_emb(self, id_, phrases, emb, emb_type='dense'):
def context_emb(self, id_, phrases, emb, metadata=None, emb_type='dense'):
if not os.path.exists(self._context_emb_dir):
os.makedirs(self._context_emb_dir)
savez = scipy.sparse.save_npz if emb_type == 'sparse' else np.savez_compressed
Expand All @@ -113,20 +113,30 @@ def context_emb(self, id_, phrases, emb, emb_type='dense'):
with open(json_path, 'w') as fp:
json.dump(phrases, fp)

def context_load(self, emb_type='dense'):
if metadata is not None:
metadata_path = os.path.join(self._context_emb_dir, '%s.metadata' % id_)
with open(metadata_path, 'w') as fp:
json.dump(metadata, fp)

def context_load(self, metadata=False, emb_type='dense'):
    """Iterate over the context embeddings saved in the context embedding directory.

    Every ``<name>.json`` file found in ``self._context_emb_dir`` is paired
    with ``<name>.npz`` and, when requested, ``<name>.metadata`` from the
    same directory.

    Args:
        metadata: when true, also load ``<name>.metadata`` (a JSON file) and
            yield it as a third element.
        emb_type: ``'dense'`` reads array ``arr_0`` from the ``.npz`` file with
            numpy; any other value loads the file as a scipy sparse matrix.

    Yields:
        ``(phrases, emb)``, or ``(phrases, emb, metadata)`` when ``metadata``
        is true. The shape is the same for every item of a single call.
    """
    paths = os.listdir(self._context_emb_dir)
    json_paths = tuple(os.path.join(self._context_emb_dir, path)
                       for path in paths if os.path.splitext(path)[1] == '.json')
    for json_path in json_paths:
        stem = os.path.splitext(json_path)[0]
        npz_path = '%s.npz' % stem
        metadata_path = '%s.metadata' % stem

        with open(json_path, 'r') as fp:
            phrases = json.load(fp)

        if emb_type == 'dense':
            # Close the archive once the array has been read out of it.
            with np.load(npz_path) as archive:
                emb = archive['arr_0']
        else:
            emb = scipy.sparse.load_npz(npz_path)

        if metadata:
            # Keep the loaded content in its own name: rebinding the
            # `metadata` flag would change the branch taken for later files
            # whenever a file holds a falsy value such as `{}`.
            with open(metadata_path, 'r') as fp:
                cur_metadata = json.load(fp)
            yield phrases, emb, cur_metadata
        else:
            yield phrases, emb

def archive(self):
if self._mode == 'embed' or self._mode == 'embed_context':
Expand Down
4 changes: 3 additions & 1 deletion squad/baseline/processor.py
Expand Up @@ -193,7 +193,9 @@ def postprocess_context(self, example, context_output):
phrases = tuple(_get_pred(context, context_spans, yp1, yp2) for yp1, yp2 in pos_tuple)
if self._emb_type == 'sparse':
out = csc_matrix(out)
return example['cid'], phrases, out
metadata = {'context': context,
'answer_spans': tuple((context_spans[yp1][0], context_spans[yp2][1]) for yp1, yp2 in pos_tuple)}
return example['cid'], phrases, out, metadata

def postprocess_context_batch(self, dataset, model_input, context_output):
results = tuple(self.postprocess_context(dataset[idx], context_output[i])
Expand Down
66 changes: 51 additions & 15 deletions squad/main.py
Expand Up @@ -233,8 +233,10 @@ def embed(args):
context_output = model.get_context(**test_batch)
context_results = processor.postprocess_context_batch(test_dataset, test_batch, context_output)

for id_, phrases, matrix in context_results:
interface.context_emb(id_, phrases, matrix, emb_type=args.emb_type)
for id_, phrases, matrix, metadata in context_results:
if not args.metadata:
metadata = None
interface.context_emb(id_, phrases, matrix, metadata=metadata, emb_type=args.emb_type)

if args.mode == 'embed' or args.mode == 'embed_question':

Expand Down Expand Up @@ -279,27 +281,37 @@ def serve(args):
with torch.no_grad():
model.eval()
if args.mode == 'serve_demo':
from flask import Flask, request, jsonify
from flask_cors import CORS

from tornado.wsgi import WSGIContainer
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop

if args.emb_type == 'dense':
import faiss
print('Loading phrase-vector pairs')
d = 4 * args.hidden_size * args.num_heads
phrases = []
embs = []
for cur_phrases, emb in interface.context_load(emb_type=args.emb_type):
results = []
for cur_phrases, emb, metadata in interface.context_load(metadata=True, emb_type=args.emb_type):
phrases.extend(cur_phrases)
embs.append(emb)
for span in metadata['answer_spans']:
results.append([metadata['context'], span[0], span[1]])
emb = np.concatenate(embs, 0)

index = faiss.IndexFlatIP(d) # Exact Search
search_index = faiss.IndexFlatIP(d) # Exact Search

if args.nlist != args.nprobe:
# Approximate Search. nlist > nprobe makes it faster and less accurate
index = faiss.IndexIVFFlat(index, d, args.nlist, faiss.METRIC_INNER_PRODUCT)
index.train(emb)
index.add(emb)
index.nprobe = args.nprobe
search_index = faiss.IndexIVFFlat(search_index, d, args.nlist, faiss.METRIC_INNER_PRODUCT)
search_index.train(emb)
search_index.add(emb)
search_index.nprobe = args.nprobe
else:
index.add(emb)
search_index.add(emb)

def retrieve(question, k):
example = {'question': question, 'id': 'real', 'idx': 0}
Expand All @@ -309,14 +321,38 @@ def retrieve(question, k):
question_output = model.get_question(**batch)
question_results = processor.postprocess_question_batch(dataset, batch, question_output)
id_, emb = question_results[0]
D, I = index.search(emb, k)
out = [phrases[i] for i in I[0]]
D, I = search_index.search(emb, k)
out = [tuple(results[i]) + ('%.4r' % d.item(),) for d, i in zip(D[0], I[0])]
return out
else:
raise NotImplementedError()

# Demo server. Requires flask and tornado

app = Flask(__name__, static_url_path='/static')

app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
CORS(app)

@app.route('/')
def index():
return app.send_static_file('index.html')

@app.route('/files/<path:path>')
def static_files(path):
return app.send_static_file('files/' + path)

@app.route('/api', methods=['GET'])
def api():
query = request.args['query']
out = retrieve(query, 5)
return jsonify(out)

print('Starting server at %d' % args.port)
http_server = HTTPServer(WSGIContainer(app))
http_server.listen(args.port)
IOLoop.instance().start()

import time
start = time.time()
print(retrieve('When did Martin Luther die?', 10))
print(time.time() - start)


def main():
Expand Down
6 changes: 5 additions & 1 deletion squad/requirements.txt
Expand Up @@ -3,6 +3,10 @@ numpy==1.15.2
scipy==1.1.0
nltk==3.3
allennlp==0.6.1

tqdm
gensim
faiss
faiss

tornado
flask
flask-cors
Binary file added squad/static/files/pika.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions squad/static/files/style.css
@@ -0,0 +1,5 @@
/* Make the page at least as tall as the viewport and a positioning context for the footer. */
html { position: relative; min-height: 100%; }
/* Leave room at the bottom of the body so content does not sit under the footer. */
body { margin-bottom: 60px; }
/* Footer bar anchored to the bottom of the page, full width, small centered text. */
.footer { position: absolute; bottom: 0; width: 100%; height: 40px; line-height: 15px; background-color: #f5f5f5; padding-top: 5px; font-size: 12px; text-align: center;}
/* Text in labels and the footer element cannot be selected. */
label, footer { user-select: none; }
/* Give the first entry of a result list its own background and text color. */
.list-group-item:first-of-type { background-color: #e0f2f1; color: #00695c; }
120 changes: 120 additions & 0 deletions squad/static/index.html
@@ -0,0 +1,120 @@
<!DOCTYPE html>
<html>
<meta charset="utf-8">
<title>PIQA Demo</title>
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/css/bootstrap.min.css">
<link rel="stylesheet" href="files/style.css">

<script src="https://use.fontawesome.com/releases/v5.2.0/js/all.js"></script>
<script src="https://code.jquery.com/jquery-3.3.1.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.3/umd/popper.min.js"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/js/bootstrap.min.js"></script>

<body>
<nav class="navbar navbar-expand-sm bg-dark navbar-dark">
<a class="navbar-brand" href="./">PIQA Demo</a>
<ul class="navbar-nav">
<li class="nav-item"> <a class="nav-link" target="_blank" href="http://aclweb.org/anthology/D18-1052">Paper</a> </li>
</ul>
</nav>
<div class="container">
<div class="mt-4 text-center">
<img src="files/pika.jpg" height="300" class="rounded" alt="pika" title="Ask me anything!"/>
</div>

<div class="input-group mb-1 mt-4">
<input id="question" type="text" class="form-control" placeholder="Write question" aria-label="Write question">
<div class="input-group-append">
<button id="search" class="btn btn-secondary" type="button">
<i class="fa fa-search"></i>
</button>
</div>
</div>

<div class="row">
<div id="ret-time" class="text-secondary small ml-2 col">Elapsed time: </div>
<div class="custom-control custom-checkbox mr-3">
<input type="checkbox" class="custom-control-input small" id="realtime_chk">
<label class="custom-control-label small" for="realtime_chk">Real-time Search</label>
</div>
</div>
<hr/>

<div class="card">
<ul id="ret-results" class="list-group list-group-flush">
<li class="list-group-item"></li>
</ul>
</div>
</div>

<footer class="footer">
<div class="container">
<span class="text-muted">
<strong>Phrase-Indexed Question Answering: A New Challenge for Scalable Document Comprehension</strong><br/>
Demo page made by <a target="_blank" href="https://antest1.github.io/">Gyeongbok Lee</a>
</span>
</div>
</footer>


<script>
var tout_id = 0;
$("#realtime_chk").prop('checked', true);

$("#question").bind("input", function() {
var query = $("#question").val();
clearTimeout(tout_id);
var is_real = $("#realtime_chk").is(":checked") == true;
if (is_real) {
if (query.trim().length > 0) {
tout_id = window.setTimeout(execute, 100, query);
} else {
init_result();
}
}
});

$("#search").click(function() {
var query = $("#question").val();
if (query.trim().length > 0) {
execute(query);
} else {
init_result();
}
});

// Reset the result area: one empty list entry and an elapsed-time label with no value.
function init_result() {
  var blank_entry = "<li class=\"list-group-item\"></li>";
  var empty_label = "Elapsed time: ";
  $("#ret-results").html(blank_entry);
  $("#ret-time").text(empty_label);
}

// Build the HTML for one search result.
//
// item[0] is the context text; item[1] and item[2] are the start offset
// (inclusive) and end offset (exclusive) of the answer span inside it.
// Returns the text, HTML-escaped, with the answer span wrapped in <strong>.
// Offsets are clamped to the text, so an empty or out-of-range span still
// yields well-formed markup.
function highlight(item) {
  // The context is plain text from the corpus; escape it before it is
  // inserted into the page as markup.
  function escape_html(raw) {
    return raw
      .replace(/&/g, "&amp;")
      .replace(/</g, "&lt;")
      .replace(/>/g, "&gt;")
      .replace(/"/g, "&quot;")
      .replace(/'/g, "&#39;");
  }

  var text = String(item[0]);
  var start = Math.max(0, Math.min(item[1], text.length));
  var end = Math.max(start, Math.min(item[2], text.length));

  return escape_html(text.slice(0, start))
    + "<strong>" + escape_html(text.slice(start, end)) + "</strong>"
    + escape_html(text.slice(end));
}

// Query the /api endpoint with `text` and render the returned results.
//
// Each result is expected to be an array [context, start, end, score]; the
// first three entries are handed to highlight(), the fourth is shown as the
// score. The elapsed wall-clock time of the request is shown in #ret-time.
function execute(text) {
  // Real-time search can have several requests in flight. Number them so a
  // slow, older response cannot overwrite the results of a newer query.
  var request_id = (execute.last_request_id || 0) + 1;
  execute.last_request_id = request_id;

  // Keep the timer local to this call; a shared global would be clobbered by
  // overlapping requests.
  var start_time = Date.now();
  $("#ret-time").text("Elapsed time: ");

  $.get("/api?query=" + encodeURIComponent(text), function(result) {
    if (request_id !== execute.last_request_id) {
      return;
    }
    var elapsed = (Date.now() - start_time) / 1000;
    $("#ret-time").text("Elapsed time: " + elapsed + "s");

    var list = $("#ret-results").empty();
    for (var i = 0; i < result.length; i++) {
      // highlight() returns markup, so it goes in with .html(); the score is
      // plain data and goes in with .text().
      var row = $("<div class=\"row\"></div>")
        .append($("<div class=\"col-10\"></div>").html(highlight(result[i])))
        .append($("<div class=\"col-2 text-right\"></div>").text(result[i][3]));
      list.append($("<li class=\"list-group-item\"></li>").append(row));
    }
  });
}
</script>

</body>

</html>

0 comments on commit 4f07a6b

Please sign in to comment.