-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathvisualizer.py
executable file
·137 lines (113 loc) · 4.4 KB
/
visualizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import shutil
from collections import OrderedDict
import http.server
import socketserver
import argparse
import json
import os
import numpy as np
from tqdm import tqdm
from jinja2 import Environment, FileSystemLoader
from basic_cnn.evaluator import get_span_score_pairs, get_best_span
def bool_(string):
    """Convert the literal strings ``'True'`` / ``'False'`` to the matching bool.

    Suitable as an ``argparse`` ``type=`` converter: argparse turns a
    ``ValueError`` raised by a converter into a readable usage error.

    Args:
        string: the text to convert; only the exact spellings ``'True'`` and
            ``'False'`` are accepted.

    Returns:
        ``True`` or ``False``.

    Raises:
        ValueError: for any other input (a subclass of the ``Exception`` the
            previous version raised, so existing handlers still catch it).
    """
    if string == 'True':
        return True
    if string == 'False':
        return False
    raise ValueError("expected 'True' or 'False', got {!r}".format(string))
def get_args(argv=None):
    """Parse the visualizer's command-line options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour;
            passing a list lets callers and tests parse without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the options defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default='basic',
                        help="model directory name under out/")
    parser.add_argument("--data_type", type=str, default='dev',
                        help="data split used in the eval and data file names")
    parser.add_argument("--step", type=int, default=5000,
                        help="step of the eval dump (zero-padded to 6 digits in its file name)")
    parser.add_argument("--template_name", type=str, default="visualizer.html",
                        help="Jinja2 template file in the templates/ directory next to this script")
    parser.add_argument("--num_per_page", type=int, default=100,
                        help="number of questions rendered per HTML page")
    parser.add_argument("--data_dir", type=str, default="data/squad",
                        help="directory holding data_<type>.json and shared_<type>.json")
    parser.add_argument("--port", type=int, default=8000,
                        help="port of the local HTTP server")
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="host/interface of the local HTTP server")
    # Kept as a string: the visualizer compares it against the literal 'True'.
    parser.add_argument("--open", type=str, default='False',
                        help="'True' to open the served pages in a browser")
    parser.add_argument("--run_id", type=str, default="0",
                        help="run id (zero-padded to 2 digits in the output path)")
    return parser.parse_args(argv)
def _decode(decoder, sent):
return " ".join(decoder[idx] for idx in sent)
def accuracy2_visualizer(args):
    """Render one evaluation dump as paginated HTML and serve it over HTTP forever.

    Reads ``out/<model_name>/<run_id>/eval/<data_type>-<step>.json`` plus
    ``<data_dir>/data_<data_type>.json`` and ``<data_dir>/shared_<data_type>.json``.
    It renders ``args.num_per_page`` questions per page with the Jinja2 template
    ``templates/<args.template_name>`` (relative to this file) and writes the pages
    into the first non-existent ``/tmp/list_results<N>`` directory. Each page is
    named after the zero-padded index of its first row. It then changes the
    working directory there and serves it at ``args.host:args.port``.

    Args:
        args: namespace as produced by ``get_args()``.

    Raises:
        ValueError: if ``args.num_per_page`` is smaller than 1.
    """
    num_per_page = args.num_per_page
    if num_per_page < 1:
        # Guard: the page-boundary modulo below would otherwise divide by zero.
        raise ValueError("num_per_page must be >= 1, got {}".format(num_per_page))
    data_type = args.data_type
    run_id = args.run_id.zfill(2)
    eval_path = os.path.join("out", args.model_name, run_id, "eval",
                             "{}-{}.json".format(data_type, str(args.step).zfill(6)))
    eval_ = _load_json(eval_path)

    html_dir = _make_unique_dir("/tmp/list_results%d")

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    env = Environment(loader=FileSystemLoader(os.path.join(cur_dir, 'templates')))
    env.globals.update(zip=zip, reversed=reversed)
    template = env.get_template(args.template_name)

    data = _load_json(os.path.join(args.data_dir, "data_{}.json".format(data_type)))
    shared = _load_json(os.path.join(args.data_dir, "shared_{}.json".format(data_type)))

    columns = [eval_[key] for key in ('idxs', 'y', 'yp', 'yp2')]
    # zip() stops at the shortest column, so this count is used both for the
    # progress bar and for detecting the final page.
    num_records = min(len(column) for column in columns)
    rows = []
    html_path = None
    for i, (idx, yi, ypi, yp2i) in tqdm(enumerate(zip(*columns)), total=num_records):
        rows.append(_build_row(data, shared, idx, yi, ypi, yp2i))
        if i % num_per_page == 0:
            # A new page starts; name it after its first row index.
            html_path = os.path.join(html_dir, "%s.html" % str(i).zfill(8))
        if (i + 1) % num_per_page == 0 or (i + 1) == num_records:
            _write_page(template, html_path, rows)
            rows = []

    _serve(html_dir, args.host, args.port, args.open == 'True')


def _load_json(path):
    """Print a progress line, then parse and return the JSON file at *path* (closing it)."""
    print("loading {}".format(path))
    with open(path, 'r') as fh:
        return json.load(fh)


def _make_unique_dir(pattern):
    """Create and return ``pattern % n`` for the smallest n >= 0 not already taken."""
    n = 0
    while True:
        path = pattern % n
        try:
            # mkdir itself is the existence test, so there is no check-then-create race.
            os.mkdir(path)
        except FileExistsError:
            n += 1
        else:
            return path


def _build_row(data, shared, idx, yi, ypi, yp2i):
    """Build the template context dict for one evaluated question."""
    id_, q, rx, answers = (data[key][idx] for key in ('ids', 'q', '*x', 'answerss'))
    # data['*x'][idx] is a two-level reference into shared['x'], not the text itself.
    x = shared['x'][rx[0]][rx[1]]
    para = [list(sent) for sent in x]
    span = get_best_span(ypi, yp2i)
    ap = get_segment(para, span)
    # NOTE(review): span is assumed to be ((sent, start), (sent, end_exclusive));
    # the -1 below relies on that. Verify against basic_cnn.evaluator.
    score = ypi[span[0][0]][span[0][1]] * yp2i[span[1][0]][span[1][1] - 1]
    return {
        'id': id_,
        'title': "Hello world!",
        'ques': [" ".join(q)],
        'para': para,
        'y': yi[0][0],
        'y2': yi[0][1],
        'yp': ypi,
        'yp2': yp2i,
        'a': answers,
        'ap': ap,
        'score': "{:.3f}".format(score),
    }


def _write_page(template, html_path, rows):
    """Render *rows* with *template* and write the result to *html_path* as UTF-8."""
    var_dict = {'title': "Accuracy Visualization", 'rows': rows}
    with open(html_path, "wb") as f:
        f.write(template.render(**var_dict).encode('UTF-8'))


def _serve(html_dir, host, port, open_browser):
    """Serve *html_dir* over HTTP at host:port until interrupted, optionally opening a browser."""
    # SimpleHTTPRequestHandler serves the current working directory.
    os.chdir(html_dir)

    class MyHandler(http.server.SimpleHTTPRequestHandler):
        # Overridden to suppress the per-request log line.
        def log_message(self, format, *args):
            pass

    httpd = socketserver.TCPServer((host, port), MyHandler)
    try:
        if open_browser:
            import webbrowser
            # A wildcard bind address is not a portable URL host; browse via localhost.
            browse_host = "localhost" if host in ("", "0.0.0.0") else host
            webbrowser.open("http://%s:%d" % (browse_host, port))
        print("serving at %s:%d" % (host, port))
        httpd.serve_forever()
    finally:
        httpd.server_close()
def get_segment(para, span):
    """Return the space-joined words of one sentence of *para* covered by *span*.

    *span* is read as ``((sent_idx, begin), (_, end))``. The sentence comes from
    the start position only, because the end position's sentence index is not
    consulted. Words ``begin`` up to, but not including, ``end`` are returned.
    """
    sentence = para[span[0][0]]
    begin, end = span[0][1], span[1][1]
    return " ".join(sentence[begin:end])
if __name__ == "__main__":
    # Script entry point: read the CLI options, then build and serve the pages.
    ARGS = get_args()
    accuracy2_visualizer(ARGS)