-
Notifications
You must be signed in to change notification settings - Fork 27
/
multi_turn_dialog.py
296 lines (252 loc) · 11 KB
/
multi_turn_dialog.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
A module for multi turn dialog.
"""
import csv
from collections import Counter
from itertools import chain
import numpy as np
from .dataloader import BasicLanguageGeneration
from ..metric import MetricChain, MultiTurnPerplexityMetric, MultiTurnBleuCorpusMetric, \
MultiTurnDialogRecorder
# pylint: disable=W0223
class MultiTurnDialog(BasicLanguageGeneration):
	r"""Base class for multi-turn dialog datasets. This is an abstract class.

	Arguments:
		end_token (int): the special token that stands for end. Default: `4` (`"<eot>"`).
		ext_vocab (list): special tokens. Default: `["<pad>", "<unk>", "<go>", "<eos>", "<eot>"]`.
		key_name (list): name of subsets of the data. Default: `["train", "dev", "test"]`.

	Attributes:
		ext_vocab (list): special tokens, placed at the beginning of `vocab_list`.
			For example: `["<pad>", "<unk>", "<go>", "<eos>", "<eot>"]`.
		pad_id (int): token for padding, always equal to `0`.
		unk_id (int): token for unknown words, always equal to `1`.
		go_id (int): token at the beginning of sentences, always equal to `2`.
		eos_id (int): token at the end of sentences, always equal to `3`.
		eot_id (int): token at the end of turns, always equal to `4`.
		key_name (list): name of subsets of the data. For example: `["train", "dev", "test"]`.
		all_vocab_list (list): vocabulary list of the datasets.
		word2id (dict): a dict mapping tokens to indexes.
			Maybe you want to use :meth:`sen_to_index` instead.
		end_token (int): token for end. Default: equals `eot_id`.
	"""
	def __init__(self, \
			end_token=None, \
			ext_vocab=None, \
			key_name=None, \
			):
		# Default special tokens; "<eot>" must be present because eot_id is derived from it.
		ext_vocab = ext_vocab or ["<pad>", "<unk>", "<go>", "<eos>", "<eot>"]
		self.eot_id = ext_vocab.index("<eot>")
		# If no explicit end token is given, a turn ends at "<eot>".
		super().__init__(end_token or self.eot_id, ext_vocab, key_name)

	def get_batch(self, key, index):
		'''Get a batch of data with the specified `index`.

		Arguments:
			key (str): must be contained in `key_name`.
			index (list): a list of specified indexes.

		Returns:
			(dict): A dict that at least contains:

			* turn_length (list): A 1-d list, the number of turns in sessions.
			  Size: `[batch_size]`
			* sent_length (list): A 2-d non-padded list, the length of sentences in turns.
			  The second dimension varies between sessions.
			  Length of outer list: `[batch_size]`
			* sent (:class:`numpy.array`): A 3-d padded array containing ids of words.
			  Only provides valid words; `unk_id` is used for words outside the valid vocabulary.
			  Size: `[batch_size, max(turn_length[i]), max(sent_length)]`
			* sent_allwords (:class:`numpy.array`): A 3-d padded array containing ids of words.
			  Provides both valid and invalid words.
			  Size: `[batch_size, max(turn_length[i]), max(sent_length)]`

		Raises:
			ValueError: if `key` is not a known subset name.

		Examples:
			>>> dataloader.get_batch('train', 1)
			>>>

		Todo:
			* fix the missing example
		'''
		if key not in self.key_name:
			raise ValueError("No set named %s." % key)
		res = {}
		res["turn_length"] = [len(self.data[key]['session'][i]) for i in index]
		res["sent_length"] = []
		for i in index:
			sent_length = [len(sent) for sent in self.data[key]['session'][i]]
			res["sent_length"].append(sent_length)
		# Zero-padded (pad_id == 0) cube sized by the longest session and longest sentence
		# in this batch. Note res_sent and res["sent"] alias the same array.
		res_sent = res["sent"] = np.zeros((len(index), np.max(res['turn_length']), \
			np.max(list(chain(*res['sent_length'])))), dtype=int)
		for i, index_i in enumerate(index):
			for j, sent in enumerate(self.data[key]['session'][index_i]):
				res_sent[i, j, :len(sent)] = sent
		# Snapshot with invalid-vocab ids preserved, taken BEFORE the unk masking below.
		res["sent_allwords"] = res_sent.copy()
		# Ids at or beyond valid_vocab_len are invalid words; map them to unk in res["sent"].
		res_sent[res_sent >= self.valid_vocab_len] = self.unk_id
		return res

	def multi_turn_trim_index(self, index, ignore_first_token=False):
		'''Trim indexes for multi-turn dialog. There are 3 steps:

		* For every turn, if there is an `<eot>`, \
		  find the first `<eot>` and abandon the words after it (including the `<eot>`).
		* If a sentence is empty after trimming, discard this turn and the turns after it.
		* Ignore `<pad>` s at the end of every turn.

		Arguments:
			index (list or :class:`numpy.array`): a 2-d array of int.
				Size: [turn_length, max_sent_length]
			ignore_first_token (bool): if True, ignore the first token of each turn
				(it must be `<go>`).

		Examples:

		Todo:
			* fix the missing example
		'''
		res = []
		for turn_index in index:
			if ignore_first_token:
				turn_trim = self.trim_index(turn_index[1:])
			else:
				turn_trim = self.trim_index(turn_index)
			if turn_trim:
				res.append(turn_trim)
			else:
				# An empty turn ends the session: later turns are dropped too.
				break
		return res

	def multi_turn_sen_to_index(self, session, invalid_vocab=False):
		'''Convert a session from string to index representation.

		Arguments:
			session (list): a 2-d list of str, the tokens of each sentence in the session.
			invalid_vocab (bool): whether to provide invalid words.
				If ``False``, invalid words are transferred to `unk_id`.
				If ``True``, invalid words use their own ids.
				Default: ``False``.

		Examples:

		Todo:
			* fix the missing example
		'''
		if invalid_vocab:
			# word2id covers the whole vocabulary (valid + invalid); unknown -> unk_id.
			return list(map(lambda sent: list(map( \
				lambda word: self.word2id.get(word, self.unk_id), sent)), \
				session))
		else:
			return list(map(lambda sent: list(map( \
				self._valid_word2id, sent)), \
				session))

	def multi_turn_index_to_sen(self, index, trim=True, ignore_first_token=False):
		'''Convert a session from index to string representation.

		Arguments:
			index (list or :class:`numpy.array`): a 2-d array of int.
				Size: [turn_length, max_sent_length]
			trim (bool): if True, call :func:`multi_turn_trim_index` before conversion.
			ignore_first_token (bool): Only works when trim=True.
				If True, ignore the first token of each turn (it must be `<go>`).

		Examples:

		Todo:
			* fix the missing example
		'''
		if trim:
			index = self.multi_turn_trim_index(index, ignore_first_token=ignore_first_token)
		return list(map(lambda sent: \
			list(map(lambda word: self.all_vocab_list[word], sent)), \
			index))

	def get_teacher_forcing_metric(self, gen_prob_key="gen_prob"):
		'''Get a metric for teacher-forcing mode.

		It contains:

		* :class:`.metric.MultiTurnPerplexityMetric`

		Arguments:
			gen_prob_key (str): default: `gen_prob`.
				Refer to :class:`.metric.MultiTurnPerplexityMetric`.
		'''
		return MultiTurnPerplexityMetric(self, gen_prob_key=gen_prob_key)

	def get_inference_metric(self, gen_key="gen"):
		'''Get a metric for inference.

		It contains:

		* :class:`.metric.MultiTurnBleuCorpusMetric`
		* :class:`.metric.MultiTurnDialogRecorder`

		Arguments:
			gen_key (str): default: "gen". Refer to :class:`.metric.MultiTurnBleuCorpusMetric` or
				:class:`.metric.MultiTurnDialogRecorder`.
		'''
		metric = MetricChain()
		metric.add_metric(MultiTurnBleuCorpusMetric(self, gen_key=gen_key))
		metric.add_metric(MultiTurnDialogRecorder(self, gen_key=gen_key))
		return metric
class UbuntuCorpus(MultiTurnDialog):
	'''A dataloader for the Ubuntu Dialogue Corpus dataset
	(reads `ubuntu_corpus_{train,dev,test}.csv` files).

	Arguments:
		file_path (str): a str indicating the dir of the Ubuntu dataset.
		min_vocab_times (int): A cut-off threshold of `UNK` tokens. All tokens appearing
			less than `min_vocab_times` will be replaced by `<unk>`. Default: 10.
		max_sen_length (int): All sentences longer than `max_sen_length` will be shortened
			to the first `max_sen_length` tokens. Default: 50.
		max_turn_length (int): All sessions longer than `max_turn_length` will be shortened
			to the first `max_turn_length` sentences. Default: 20.
		invalid_vocab_times (int): A cut-off threshold of invalid tokens. All tokens appearing
			not less than `invalid_vocab_times` in the **whole dataset** (except valid words) will be
			marked as invalid words. Otherwise, they are unknown words, both in training and
			testing stages. Default: 0 (no unknown words).

	Refer to :class:`.MultiTurnDialog` for attributes and methods.

	Todo:
		* add references
	'''
	def __init__(self, file_path, min_vocab_times=10, max_sen_length=50, max_turn_length=20, \
			invalid_vocab_times=0):
		self._file_path = file_path
		self._min_vocab_times = min_vocab_times
		self._max_sen_length = max_sen_length
		self._max_turn_length = max_turn_length
		self._invalid_vocab_times = invalid_vocab_times
		# Base __init__ is expected to call _load_data; settings above must be set first.
		super(UbuntuCorpus, self).__init__()

	def _load_data(self):
		r'''Load the dataset. Invoked by `MultiTurnDialog.__init__`.

		Returns:
			(tuple): `(vocab_list, valid_vocab_len, data, data_size)` where `vocab_list`
			holds valid words first, then invalid-but-known words.
		'''
		origin_data = {}
		for key in self.key_name:
			with open('%s/ubuntu_corpus_%s.csv' % (self._file_path, key)) as data_file:
				raw_data = list(csv.reader(data_file))
				head = raw_data[0]
				# dev/test files carry a 'Label' column; keep only positive (true response) rows.
				# NOTE(review): context (d[0]) and utterance (d[1]) are concatenated without a
				# separator — presumably the context already ends with '__eot__'; confirm
				# against the CSV format.
				if head[2] == 'Label':
					raw_data = [d[0] + d[1] for d in raw_data[1:] if d[2] == '1.0']
				else:
					raw_data = [d[0] + d[1] for d in raw_data[1:]]
				# Split a raw dialog into turns at '__eot__'; '__eou__' becomes '<eos>'
				# so each turn is a token list with sentence boundaries preserved.
				raw2line = lambda raw: [sent.strip().split() \
					for sent in raw.strip().replace('__eou__', '<eos>').split('__eot__')]
				origin_data[key] = {'session': list(map(raw2line, raw_data))}

		# Valid vocabulary comes from the training set only.
		raw_vocab_list = list(chain(*chain(*(origin_data['train']['session']))))
		# Important: Sort the words preventing the index changes between different runs
		vocab = sorted(Counter(raw_vocab_list).most_common(), key=lambda pair: (-pair[1], pair[0]))
		left_vocab = list(filter(lambda x: x[1] >= self._min_vocab_times, vocab))
		left_vocab = list(map(lambda x: x[0], left_vocab))
		# '<eos>' is already in ext_vocab; drop the duplicate produced by the '__eou__'
		# substitution. NOTE(review): raises ValueError if the training set contains no
		# '__eou__' marker at all.
		left_vocab.remove('<eos>')
		vocab_list = self.ext_vocab + left_vocab
		valid_vocab_len = len(vocab_list)
		valid_vocab_set = set(vocab_list)

		# Extend counts with dev/test words, then append "invalid but known" words
		# (frequent enough, not already valid) after the valid vocabulary.
		for key in self.key_name:
			if key == 'train':
				continue
			raw_vocab_list.extend(list(chain(*chain(*(origin_data[key]['session'])))))
		vocab = sorted(Counter(raw_vocab_list).most_common(), \
			key=lambda pair: (-pair[1], pair[0]))
		left_vocab = list( \
			filter( \
				lambda x: x[1] >= self._invalid_vocab_times and x[0] not in valid_vocab_set, \
				vocab))
		left_vocab = list(map(lambda x: x[0], left_vocab))
		vocab_list.extend(left_vocab)

		print("valid vocab list length = %d" % valid_vocab_len)
		print("vocab list length = %d" % len(vocab_list))

		word2id = {w: i for i, w in enumerate(vocab_list)}
		# Wrap each turn with <go> ... <eot> and cap its length (the cap includes
		# the two wrapper tokens).
		line2id = lambda line: ([self.go_id] + list(\
			map(lambda word: word2id.get(word, self.unk_id), line)) + \
			[self.eot_id])[:self._max_sen_length]

		data = {}
		data_size = {}
		for key in self.key_name:
			data[key] = {}
			data[key]['session'] = [list(map(line2id, session[:self._max_turn_length])) \
				for session in origin_data[key]['session']]
			data_size[key] = len(data[key]['session'])
			# Per-subset statistics (rates computed against pre-cut token counts).
			vocab = list(chain(*chain(*(origin_data[key]['session']))))
			vocab_num = len(vocab)
			oov_num = len(list(filter(lambda word: word not in word2id, vocab)))
			invalid_num = len(list(filter(lambda word: word not in valid_vocab_set, vocab))) - oov_num
			sent_length = list(map(len, chain(*origin_data[key]['session'])))
			# "+ 2" accounts for the <go>/<eot> tokens added before the length cap.
			cut_word_num = np.sum(np.maximum(np.array(sent_length) - self._max_sen_length + 2, 0))
			turn_length = list(map(len, origin_data[key]['session']))
			sent_num = np.sum(turn_length)
			cut_sent_num = np.sum(np.maximum(np.array(turn_length) - self._max_turn_length, 0))
			print(("%s set. invalid rate: %f, unknown rate: %f, max sentence length before cut: %d, " + \
				"cut word rate: %f\n\tmax turn length before cut: %d, cut sentence rate: %f") % \
				(key, invalid_num / vocab_num, oov_num / vocab_num, max(sent_length), \
				cut_word_num / vocab_num, max(turn_length), cut_sent_num / sent_num))
		return vocab_list, valid_vocab_len, data, data_size