# read_data.py
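"""Data loading for the CNN/DailyMail cloze-style QA pipeline.

`MyData` lazily reads one example file at a time from disk, `DataSet` /
`MyDataSet` handle shuffling and batching, and `read_data` builds (or reloads)
the word and character vocabularies, including GloVe-backed embeddings for
words outside the trained vocabulary.
"""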
import json
import os
import random
import itertools
import math
from collections import defaultdict

import numpy as np

from cnn_dm.prepro import para2sents
from my.tensorflow import grouper
from my.utils import index


class Data(object):
    def get_size(self):
        raise NotImplementedError()

    def get_by_idxs(self, idxs):
        """
        Efficient way to obtain a batch of items from the filesystem.

        :param idxs: iterable of example indices
        :return: dict mapping each key to the list of per-example values,
            e.g. {'x': [...], 'y': [...]}
        """
        data = defaultdict(list)
        for idx in idxs:
            each_data = self.get_one(idx)
            for key, val in each_data.items():
                data[key].append(val)
        return data

    def get_one(self, idx):
        raise NotImplementedError()

    def get_empty(self):
        raise NotImplementedError()

    def __add__(self, other):
        raise NotImplementedError()


class MyData(Data):
    def __init__(self, config, root_dir, file_names):
        self.root_dir = root_dir
        self.file_names = file_names
        self.config = config

    def get_one(self, idx):
        # Each example file stores the url, paragraph, question, and answer on
        # alternating lines (with blank lines in between), followed by one
        # candidate per line in the form "<entity>:<text>".
        file_name = self.file_names[idx]
        with open(os.path.join(self.root_dir, file_name), 'r') as fh:
            url = fh.readline().strip()
            _ = fh.readline()
            para = fh.readline().strip()
            _ = fh.readline()
            ques = fh.readline().strip()
            _ = fh.readline()
            answer = fh.readline().strip()
            _ = fh.readline()
            cands = list(line.strip() for line in fh)
        cand_ents = list(cand.split(":")[0] for cand in cands)
        wordss = para2sents(para, self.config.width)
        ques_words = ques.split(" ")
        x = wordss  # paragraph as a list of sentences, each a list of words
        cx = [[list(word) for word in words] for words in wordss]  # characters
        q = ques_words
        cq = [list(word) for word in ques_words]
        y = answer
        c = cand_ents
        data = {'x': x, 'cx': cx, 'q': q, 'cq': cq, 'y': y, 'c': c, 'ids': file_name}
        return data

    def get_empty(self):
        return MyData(self.config, self.root_dir, [])

    def __add__(self, other):
        file_names = self.file_names + other.file_names
        return MyData(self.config, self.root_dir, file_names)

    def get_size(self):
        return len(self.file_names)
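
# A minimal sketch of using MyData directly (the directory and file names are
# hypothetical; `config` only needs the `width` attribute used above):
#
#   data = MyData(config, "data/cnn/train", ["0000.question", "0001.question"])
#   ex = data.get_one(0)
#   # ex['x']  -> sentences as lists of word tokens
#   # ex['cx'] -> the same words as lists of characters
#   # ex['q'] / ex['cq'] -> question words / characters
#   # ex['y'] / ex['c']  -> answer string / candidate entity ids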


class DataSet(object):
    def __init__(self, data, data_type, shared=None, valid_idxs=None):
        self.data = data  # e.g. {'X': [0, 1, 2], 'Y': [2, 3, 4]}
        self.data_type = data_type
        self.shared = shared
        total_num_examples = self.get_data_size()
        self.valid_idxs = range(total_num_examples) if valid_idxs is None else valid_idxs
        self.num_examples = total_num_examples

    def _sort_key(self, idx):
        rx = self.data['*x'][idx]
        x = self.shared['x'][rx[0]][rx[1]]
        return max(map(len, x))

    def get_data_size(self):
        if isinstance(self.data, dict):
            return len(next(iter(self.data.values())))
        elif isinstance(self.data, Data):
            return self.data.get_size()
        raise Exception()

    def get_by_idxs(self, idxs):
        if isinstance(self.data, dict):
            out = defaultdict(list)
            for key, val in self.data.items():
                out[key].extend(val[idx] for idx in idxs)
            return out
        elif isinstance(self.data, Data):
            return self.data.get_by_idxs(idxs)
        raise Exception()

    def get_one(self, idx):
        if isinstance(self.data, dict):
            out = {key: [val[idx]] for key, val in self.data.items()}
            return out
        elif isinstance(self.data, Data):
            return self.data.get_one(idx)
        raise Exception()  # keep unsupported-type handling consistent with the methods above

    def get_batches(self, batch_size, num_batches=None, shuffle=False, cluster=False):
        """
        :param batch_size: number of examples per batch
        :param num_batches: total number of batches to yield; defaults to one epoch's worth
        :param shuffle: sample examples in random order
        :param cluster: cluster examples by their lengths; this might give a performance boost
            (i.e. faster training). Only takes effect when shuffle is True.
        :return: generator of (batch_idxs, DataSet) pairs
        """
        num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
        if num_batches is None:
            num_batches = num_batches_per_epoch
        num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))
        if shuffle:
            random_idxs = random.sample(self.valid_idxs, len(self.valid_idxs))
            if cluster:
                sorted_idxs = sorted(random_idxs, key=self._sort_key)
                sorted_grouped = lambda: list(grouper(sorted_idxs, batch_size))
                grouped = lambda: random.sample(sorted_grouped(), num_batches_per_epoch)
            else:
                random_grouped = lambda: list(grouper(random_idxs, batch_size))
                grouped = random_grouped
        else:
            raw_grouped = lambda: list(grouper(self.valid_idxs, batch_size))
            grouped = raw_grouped
        batch_idx_tuples = itertools.chain.from_iterable(grouped() for _ in range(num_epochs))
        for _ in range(num_batches):
            batch_idxs = tuple(i for i in next(batch_idx_tuples) if i is not None)
            batch_data = self.get_by_idxs(batch_idxs)
            shared_batch_data = {}
            for key, val in batch_data.items():
                if key.startswith('*'):
                    # keys like '*x' hold references into `shared`; dereference them here
                    assert self.shared is not None
                    shared_key = key[1:]
                    shared_batch_data[shared_key] = [index(self.shared[shared_key], each) for each in val]
            batch_data.update(shared_batch_data)
            batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
            yield batch_idxs, batch_ds
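
    # A minimal usage sketch (hypothetical variable names; `train_set` is any
    # DataSet instance):
    #
    #   for batch_idxs, batch in train_set.get_batches(32, num_batches=1000, shuffle=True):
    #       xs = batch.data['x']  # '*'-prefixed keys arrive already dereferenced
    #       ...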

    def get_multi_batches(self, batch_size, num_batches_per_step, num_steps=None, shuffle=False, cluster=False):
        batch_size_per_step = batch_size * num_batches_per_step
        batches = self.get_batches(batch_size_per_step, num_batches=num_steps, shuffle=shuffle, cluster=cluster)
        multi_batches = (tuple(zip(grouper(idxs, batch_size, shorten=True, num_groups=num_batches_per_step),
                                   data_set.divide(num_batches_per_step))) for idxs, data_set in batches)
        return multi_batches

    def get_empty(self):
        if isinstance(self.data, dict):
            data = {key: [] for key in self.data}
        elif isinstance(self.data, Data):
            data = self.data.get_empty()
        else:
            raise Exception()
        return DataSet(data, self.data_type, shared=self.shared)

    def __add__(self, other):
        if isinstance(self.data, dict):
            data = {key: val + other.data[key] for key, val in self.data.items()}
        elif isinstance(self.data, Data):
            data = self.data + other.data
        else:
            raise Exception()
        valid_idxs = list(self.valid_idxs) + [valid_idx + self.num_examples for valid_idx in other.valid_idxs]
        return DataSet(data, self.data_type, shared=self.shared, valid_idxs=valid_idxs)

    def divide(self, integer):
        batch_size = int(math.ceil(self.num_examples / integer))
        idxs_gen = grouper(self.valid_idxs, batch_size, shorten=True, num_groups=integer)
        data_gen = (self.get_by_idxs(idxs) for idxs in idxs_gen)
        ds_tuple = tuple(DataSet(data, self.data_type, shared=self.shared) for data in data_gen)
        return ds_tuple
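
# Sketch of the multi-batch path (hypothetical names): get_multi_batches packs
# `num_batches_per_step` sub-batches into each step, e.g. to feed several
# model towers at once.
#
#   for pairs in data_set.get_multi_batches(32, 2, num_steps=100, shuffle=True):
#       for batch_idxs, sub_ds in pairs:  # one (idxs, DataSet) pair per sub-batch
#           ...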


class MyDataSet(DataSet):
    def __init__(self, data, data_type, shared=None, valid_idxs=None):
        super(MyDataSet, self).__init__(data, data_type, shared=shared, valid_idxs=valid_idxs)
        # The file list comes from shared['sorted'] (see read_data below), so the
        # last example is assumed to have the largest number of sentences.
        shared['max_num_sents'] = len(self.get_one(self.num_examples - 1)['x'])

    def _sort_key(self, idx):
        # Examples are already length-sorted on disk, so the index itself serves as the sort key.
        return idx


def read_data(config, data_type, ref, data_filter=None):
    shared_path = os.path.join(config.data_dir, "shared_{}.json".format(data_type))
    with open(shared_path, 'r') as fh:
        shared = json.load(fh)
    paths = shared['sorted']
    if config.filter_ratio < 1.0:
        # keep only the first `filter_ratio` fraction of the (length-sorted) examples
        stop = int(round(len(paths) * config.filter_ratio))
        paths = paths[:stop]
    num_examples = len(paths)
    valid_idxs = range(num_examples)
    print("Loaded {}/{} examples from {}".format(len(valid_idxs), num_examples, data_type))

    shared_path = config.shared_path or os.path.join(config.out_dir, "shared.json")
    if not ref:
        # build the vocabularies from scratch and save them for reuse
        word2vec_dict = shared['lower_word2vec'] if config.lower_word else shared['word2vec']
        word_counter = shared['lower_word_counter'] if config.lower_word else shared['word_counter']
        char_counter = shared['char_counter']
        if config.finetune:
            shared['word2idx'] = {word: idx + 3 for idx, word in
                                  enumerate(word for word, count in word_counter.items()
                                            if count > config.word_count_th or
                                            (config.known_if_glove and word in word2vec_dict))}
        else:
            assert config.known_if_glove
            assert config.use_glove_for_unk
            shared['word2idx'] = {word: idx + 3 for idx, word in
                                  enumerate(word for word, count in word_counter.items()
                                            if count > config.word_count_th and word not in word2vec_dict)}
        shared['char2idx'] = {char: idx + 2 for idx, char in
                              enumerate(char for char, count in char_counter.items()
                                        if count > config.char_count_th)}
        NULL = "-NULL-"
        UNK = "-UNK-"
        ENT = "-ENT-"
        shared['word2idx'][NULL] = 0
        shared['word2idx'][UNK] = 1
        shared['word2idx'][ENT] = 2
        shared['char2idx'][NULL] = 0
        shared['char2idx'][UNK] = 1
        with open(shared_path, 'w') as fh:
            json.dump({'word2idx': shared['word2idx'], 'char2idx': shared['char2idx']}, fh)
    else:
        # reuse the vocabularies written by a previous (non-ref) call
        with open(shared_path, 'r') as fh:
            new_shared = json.load(fh)
        for key, val in new_shared.items():
            shared[key] = val

    if config.use_glove_for_unk:
        # create new word2idx and word2vec for out-of-vocabulary words covered by GloVe
        word2vec_dict = shared['lower_word2vec'] if config.lower_word else shared['word2vec']
        new_word2idx_dict = {word: idx for idx, word in
                             enumerate(word for word in word2vec_dict.keys()
                                       if word not in shared['word2idx'])}
        shared['new_word2idx'] = new_word2idx_dict
        offset = len(shared['word2idx'])  # index offset for the new words; not used in this function
        idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
        new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
        shared['new_emb_mat'] = new_emb_mat

    data = MyData(config, os.path.join(config.root_dir, data_type), paths)
    data_set = MyDataSet(data, data_type, shared=shared, valid_idxs=valid_idxs)
    return data_set
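
# Expected call pattern (a sketch; `config` comes from the project's flag
# definitions): build the vocab on the training split first, then load it for
# the other splits via ref=True.
#
#   train_set = read_data(config, 'train', ref=False)
#   dev_set = read_data(config, 'dev', ref=True)
#   update_config(config, [train_set, dev_set])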


def get_cnn_data_filter(config):
    # Stub: no filtering is applied for CNN data. Note that read_data's
    # `data_filter` argument is currently unused; example selection happens
    # via config.filter_ratio instead.
    return True


def update_config(config, data_sets):
    # Take the maximum sizes across all provided data sets so the model's
    # placeholders can fit every split.
    config.max_num_sents = 0
    config.max_sent_size = 0
    config.max_ques_size = 0
    config.max_word_size = 0
    for data_set in data_sets:
        shared = data_set.shared
        config.max_sent_size = max(config.max_sent_size, shared['max_sent_size'])
        config.max_ques_size = max(config.max_ques_size, shared['max_ques_size'])
        config.max_word_size = max(config.max_word_size, shared['max_word_size'])
        config.max_num_sents = max(config.max_num_sents, shared['max_num_sents'])
    config.max_word_size = min(config.max_word_size, config.word_size_th)
    config.char_vocab_size = len(data_sets[0].shared['char2idx'])
    config.word_emb_size = len(next(iter(data_sets[0].shared['word2vec'].values())))
    config.word_vocab_size = len(data_sets[0].shared['word2idx'])