/
basic_decoder.py
277 lines (241 loc) · 13.5 KB
/
basic_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.core import V, to_gpu
from quicknlp.utils import assert_dims, RandomUniform
def repeat_cell_state(hidden, num_beams):
    """Tile each layer's cell state along the batch dimension, once per beam.

    Args:
        hidden: iterable of per-layer states; each entry is either a tensor
            whose dim 1 is the batch, or an (h, c) pair of such tensors
            (LSTM-style). May be ``None`` (no initial state).
        num_beams: number of beams each batch element is expanded to.

    Returns:
        A list with the same structure where dim 1 becomes ``bs * num_beams``,
        or ``None`` if ``hidden`` is ``None``.
    """
    # Mirror select_hidden_by_index's None handling so callers (e.g.
    # _topk_forward, whose `hidden` defaults to None upstream) can pass an
    # uninitialized state straight through instead of crashing on iteration.
    if hidden is None:
        return None
    results = []
    for row in hidden:
        if isinstance(row, (list, tuple)):
            # (hidden, cell) pair: repeat both tensors along the batch dim.
            state = (row[0].repeat(1, num_beams, 1), row[1].repeat(1, num_beams, 1))
        else:
            state = row.repeat(1, num_beams, 1)
        results.append(state)
    return results
def reshape_parent_indices(indices, bs, num_beams):
    """Turn per-batch parent beam ids into flat indices over ``bs * num_beams``.

    Each slot in the flattened beam layout stores which parent beam (0..nb-1)
    it descends from *within its own batch element*; adding the batch
    element's base offset (``batch_index * num_beams``) produces indices that
    can be used directly against a flattened hidden state.
    """
    per_batch_base = torch.arange(end=bs) * num_beams
    offsets = V(per_batch_base.unsqueeze_(1).repeat(1, num_beams).view(-1).long())
    return indices + offsets
def select_hidden_by_index(hidden, indices):
    """Reorder every layer's cell state along the batch dim by ``indices``.

    Used after a beam-search step to align hidden states with the surviving
    beams. Each entry of ``hidden`` is either a tensor (dim 1 = batch) or an
    (h, c) pair of tensors; ``None`` passes through unchanged.
    """
    if hidden is None:
        return None

    def pick(tensor):
        # Select along dim 1 (the flattened bs * num_beams dimension).
        return torch.index_select(tensor, 1, indices)

    selected = []
    for layer_state in hidden:
        if isinstance(layer_state, (list, tuple)):
            selected.append((pick(layer_state[0]), pick(layer_state[1])))
        else:
            selected.append(pick(layer_state))
    return selected
class Decoder(nn.Module):
    """Autoregressive decoder wrapping a recurrent ``decoder_layer``.

    Supports three decoding modes, selected by ``num_beams`` in ``forward``:
    teacher forcing (0), greedy search (1), and top-k beam search (>1).
    Greedy/beam decoding record the generated token ids in ``self.beam_outputs``.
    """

    # Hard cap on decoding steps during training (see _greedy_forward) so a
    # never-ending decode cannot exhaust memory.
    MAX_STEPS_ALLOWED = 320

    def __init__(self, decoder_layer, projection_layer, max_tokens, eos_token, pad_token,
                 embedding_layer: torch.nn.Module):
        """
        Args:
            decoder_layer: recurrent stack; must expose ``nlayers``, ``hidden``,
                ``reset(bs)`` and return per-layer outputs when called.
            projection_layer: maps the last layer's output to token logits;
                may be None, in which case raw decoder outputs are returned.
            max_tokens: maximum decoding steps at inference time.
            eos_token: token id that marks a finished sequence.
            pad_token: token id used to pad finished beams during beam search.
            embedding_layer: token embedding; must expose ``emb_size``.
        """
        super().__init__()
        self.decoder_layer = decoder_layer
        self.nlayers = decoder_layer.nlayers
        self.projection_layer = projection_layer
        self.bs = 1  # batch size; refreshed on every forward call
        self.max_iterations = max_tokens
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.beam_outputs = None  # token ids from the last greedy/beam decode
        self.embedding_layer = embedding_layer
        self.emb_size = embedding_layer.emb_size
        self.pr_force = 0.0  # probability of teacher forcing a step in _greedy_forward
        self.random = RandomUniform()

    def reset(self, bs):
        """Reset the decoder layer's hidden state for batch size ``bs``."""
        self.decoder_layer.reset(bs)

    def forward(self, inputs, hidden=None, num_beams=0, constraints=None):
        """Dispatch to teacher-forcing / greedy / beam decoding by ``num_beams``.

        ``inputs`` is a [sl, bs] tensor of token ids. Returns [sl, bs, nt]
        logits (or raw outputs if there is no projection layer).
        NOTE(review): falls through and returns None for num_beams < 0.
        """
        self.bs = inputs.size(1)
        if num_beams == 0:  # zero beams, a.k.a. teacher forcing
            return self._train_forward(inputs, hidden, constraints)
        elif num_beams == 1:  # one beam a.k.a. greedy search
            return self._greedy_forward(inputs, hidden, constraints)
        elif num_beams > 1:  # multiple beams a.k.a topk search
            return self._beam_forward(inputs, hidden, num_beams, constraints)

    def _beam_forward(self, inputs, hidden, num_beams, constraints=None):
        # Indirection point so subclasses can swap the beam-search strategy.
        return self._topk_forward(inputs, hidden, num_beams, constraints)

    def _train_forward(self, inputs, hidden=None, constraints=None):
        """One teacher-forced pass over the whole target sequence."""
        inputs = self.embedding_layer(inputs)
        if constraints is not None:
            # constraints should have dim [1, bs, hd] and inputs [sl, bs, hd];
            # the constraint vector is appended to every timestep's embedding.
            inputs = torch.cat([inputs, constraints.repeat(inputs.size(0), 1, 1)], dim=-1)
        # outputs are the outputs of every layer
        outputs = self.decoder_layer(inputs, hidden)
        # we project only the output of the last layer
        outputs = self.projection_layer(outputs[-1]) if self.projection_layer is not None else outputs[-1]
        return outputs

    def _greedy_forward(self, inputs, hidden=None, constraints=None):
        """Decode token-by-token, feeding back the argmax of each step.

        During training the step budget is capped by MAX_STEPS_ALLOWED and,
        with probability ``pr_force``, the ground-truth token from ``inputs``
        is fed instead of the model's own prediction (scheduled sampling).
        """
        dec_inputs = inputs
        max_iterations = min(dec_inputs.size(0), self.MAX_STEPS_ALLOWED) if self.training else self.max_iterations
        inputs = V(inputs[:1].data)  # inputs should be only first token initially [1,bs]
        sl, bs = inputs.size()
        finished = to_gpu(torch.zeros(bs).byte())  # per-sequence EOS flags
        iteration = 0
        self.beam_outputs = inputs.clone()
        final_outputs = []
        while not finished.all() and iteration < max_iterations:
            # output should be List[[sl, bs, layer_dim], ...] sl should be one
            if 0 < iteration and self.training and 0. < self.random() < self.pr_force:
                # teacher forcing: replace the model's token with the gold one
                inputs = dec_inputs[iteration].unsqueeze(0)
            output = self.forward(inputs, hidden=hidden, num_beams=0, constraints=constraints)
            hidden = self.decoder_layer.hidden
            final_outputs.append(output)  # dim should be [sl=1, bs, nt]
            # inputs are the indices dims [1,bs]; repackage the var to avoid grad backwards
            inputs = assert_dims(V(output.data.max(dim=-1)[1]), [1, bs])
            iteration += 1
            self.beam_outputs = assert_dims(torch.cat([self.beam_outputs, inputs], dim=0), [iteration + 1, bs])
            new_finished = inputs.data == self.eos_token
            finished = finished | new_finished
        # stop if the output is too big to fit in memory (loop capped above)
        self.beam_outputs = self.beam_outputs.view(-1, bs, 1)
        # outputs should be [sl, bs, nt]
        outputs = torch.cat(final_outputs, dim=0)
        return outputs

    def _topk_forward(self, inputs, hidden, num_beams, constraints=None):
        """Beam search: keep the ``num_beams`` highest-logprob hypotheses per batch element.

        Beams are flattened into the batch dim ([1, bs * num_beams]); the
        decoder's hidden state is reordered each step to follow the surviving
        parents via select_hidden_by_index.
        """
        sl, bs = inputs.size()
        # initial logprobs should be zero (pr of <sos> token in the start is 1)
        logprobs = torch.zeros_like(inputs[:1]).view(1, bs, 1).float()  # shape will be [sl, bs, 1]
        inputs = inputs[:1].repeat(1, num_beams)  # inputs should be only first token initially [1,bs x num_beams]
        finished = to_gpu(torch.zeros(bs * num_beams).byte())  # per-beam EOS flags
        iteration = 0
        final_outputs = []
        self.beam_outputs = inputs.clone()
        hidden = repeat_cell_state(hidden, num_beams)
        while not finished.all() and iteration < self.max_iterations:
            # output should be List[[sl, bs * num_beams, layer_dim], ...] sl should be one
            output = self.forward(inputs, hidden=hidden, num_beams=0, constraints=constraints)
            hidden = self.decoder_layer.hidden
            final_outputs.append(output)
            # we take the output of the last layer with dims [1, bs, output_dim]
            # and get the indices of the top k for every bs
            new_logprobs = F.log_softmax(output, dim=-1)  # [1, bs x num_beams, nt]
            num_tokens = new_logprobs.size(2)
            # accumulate the running beam scores onto this step's logprobs
            new_logprobs = new_logprobs.view(1, bs, num_beams, num_tokens) + logprobs.unsqueeze(-1)  # [1, bs, nb, nt]
            # mask logprobs accordingly (finished beams / first iteration)
            new_logprobs = self.mask_logprobs(bs, finished, iteration, logprobs, new_logprobs, num_beams, num_tokens)
            # TODO implement stochastic beam search
            # get the top logprobs and their indices over the flattened (beam, token) axis
            logprobs, beams = torch.topk(new_logprobs, k=num_beams, dim=-1)  # [1, bs, num_beams]
            # decompose flat index: parent beam id and token id.
            # NOTE(review): relies on legacy integer division of Long tensors;
            # on newer PyTorch `/` is true division — should be `//`. Verify torch version.
            parents = beams / num_tokens
            inputs = beams % num_tokens
            parent_indices = reshape_parent_indices(parents.view(-1), bs=bs, num_beams=num_beams)
            # reorder hidden state and bookkeeping to follow the surviving parents
            self.decoder_layer.hidden = select_hidden_by_index(self.decoder_layer.hidden, indices=parent_indices)
            finished = torch.index_select(finished, 0, parent_indices.data)
            inputs = inputs.view(1, -1).contiguous()
            self.beam_outputs = torch.index_select(self.beam_outputs, dim=1, index=parent_indices)
            self.beam_outputs = torch.cat([self.beam_outputs, inputs], dim=0)
            new_finished = (inputs.data == self.eos_token).view(-1)
            finished = finished | new_finished
            iteration += 1
        self.beam_outputs = self.beam_outputs.view(-1, bs, num_beams)
        # ensure the outputs is the output of the last layer [sl, bs, nt]
        outputs = torch.cat(final_outputs, dim=0)
        return outputs

    def mask_logprobs(self, bs, finished, iteration, logprobs, new_logprobs, num_beams, num_tokens):
        """Mask beam scores before topk and flatten to [1, bs, nb * nt].

        Step 0 keeps only beam 0 (all beams start identical, so considering
        more would pick the same token k times). Later steps freeze finished
        beams: their row becomes -inf everywhere except ``pad_token``, which
        carries the beam's final score so the beam survives topk unchanged.
        """
        if iteration == 0:
            # only the first beam is considered in the first step, otherwise we would get the same result for every beam
            new_logprobs = new_logprobs[..., 0, :]
        else:
            # we have to cater for finished beams as well
            # create a mask [1, bs x nb, nt] with -inf everywhere
            mask = torch.zeros_like(new_logprobs).fill_(-1e32).view(1, bs * num_beams, num_tokens)
            f = V(finished.unsqueeze(0))
            # set the pad_token position to the last logprob for the finished ones
            mask[..., self.pad_token] = logprobs.view(1, bs * num_beams)
            # mask shape = [1, bs * nb (that are finished), nt]
            mask = mask.masked_select(f.unsqueeze(-1)).view(1, -1, num_tokens)
            # replace the rows of the finished ones with the mask (in place)
            new_logprobs.masked_scatter_(f.view(1, bs, num_beams, 1), mask)
            # flatten all beams with the tokens
            new_logprobs = new_logprobs.view(1, bs, -1)
        return new_logprobs

    @property
    def hidden(self):
        # Delegate hidden-state access to the wrapped decoder layer.
        return self.decoder_layer.hidden

    @hidden.setter
    def hidden(self, value):
        self.decoder_layer.hidden = value

    @property
    def layers(self):
        return self.decoder_layer.layers

    @property
    def output_size(self):
        # Vocabulary size when projecting; otherwise the raw decoder width.
        return self.projection_layer.output_size if self.projection_layer is not None else self.decoder_layer.output_size
class TransformerDecoder(Decoder):
    """Decoder variant for transformer-style layers without recurrent state.

    Unlike the recurrent ``Decoder``, generation re-feeds the *entire* prefix
    at every step (``inputs`` grows via ``torch.cat``) and only the last
    position of the output is used as the step prediction; there is no hidden
    state to reorder between beam-search steps.
    """

    def __init__(self, decoder_layer, projection_layer, max_tokens, eos_token, pad_token,
                 embedding_layer: torch.nn.Module):
        super().__init__(decoder_layer=decoder_layer, projection_layer=projection_layer, max_tokens=max_tokens,
                         eos_token=eos_token, pad_token=pad_token, embedding_layer=embedding_layer)

    def _train_forward(self, inputs, hidden=None, constraints=None):
        """One full-sequence pass; ``constraints`` are accepted but unused here."""
        inputs = self.embedding_layer(inputs)
        # outputs are the outputs of every layer
        outputs = self.decoder_layer(inputs, hidden)
        # we project only the output of the last layer
        outputs = self.projection_layer(outputs[-1]) if self.projection_layer is not None else outputs[-1]
        return outputs

    def _greedy_forward(self, inputs, hidden=None, constraints=None):
        """Greedy decoding: append the argmax token to the prefix each step."""
        inputs = inputs[:1]  # inputs should be only first token initially [1,bs]
        sl, bs = inputs.size()
        finished = to_gpu(torch.zeros(bs).byte())  # per-sequence EOS flags
        iteration = 0
        # beam_outputs is kept on CPU to limit GPU memory growth
        self.beam_outputs = inputs.clone().cpu()
        final_outputs = []
        while not finished.all() and iteration < self.max_iterations:
            # output should be List[[sl, bs, layer_dim], ...] sl should be one
            # step_inputs should be [1, bs]
            output = self.forward(inputs, hidden=hidden, num_beams=0)
            # only the last position is this step's prediction
            final_outputs.append(output[-1:])
            iteration += 1
            step_inputs = assert_dims(V(output[-1:].data.max(dim=-1)[1]), [1, bs])
            self.beam_outputs = assert_dims(torch.cat([self.beam_outputs, step_inputs.cpu()], dim=0),
                                            [iteration + 1, bs])
            new_finished = step_inputs.data == self.eos_token
            # grow the prefix so the next pass sees the whole history
            inputs = torch.cat([inputs, step_inputs], dim=0)
            assert_dims(inputs, [iteration + 1, bs])
            finished = finished | new_finished
        self.beam_outputs = self.beam_outputs.view(-1, bs, 1)
        outputs = torch.cat(final_outputs, dim=0)
        return outputs

    def _topk_forward(self, inputs, hidden, num_beams, constraints=None):
        """Beam search over growing prefixes; beams flattened into the batch dim."""
        sl, bs = inputs.size()
        # initial logprobs should be zero (pr of <sos> token in the start is 1)
        logprobs = torch.zeros_like(inputs[:1]).view(1, bs, 1).float()  # shape will be [sl, bs, 1]
        inputs = inputs[:1].repeat(1,
                                   num_beams)  # inputs should be only first token initially [1,bs x num_beams]
        finished = to_gpu(torch.zeros(bs * num_beams).byte())  # per-beam EOS flags
        iteration = 0
        self.beam_outputs = inputs.clone().cpu()
        hidden = repeat_cell_state(hidden, num_beams)
        final_outputs = []
        while not finished.all() and iteration < self.max_iterations:
            # output should be List[[sl, bs * num_beams, layer_dim], ...] sl should be one
            output = self.forward(inputs, hidden=hidden, num_beams=0)
            step_prediction = output[-1:]  # [sl, bs * num_beams, ntokens]
            final_outputs.append(step_prediction.cpu())
            # we take the output of the last layer with dims [1, bs, output_dim]
            # and get the indices of the top k for every bs
            new_logprobs = F.log_softmax(step_prediction, dim=-1)  # [1, bs x num_beams, nt]
            num_tokens = new_logprobs.size(2)
            # accumulate the running beam scores onto this step's logprobs
            new_logprobs = new_logprobs.view(1, bs, num_beams, num_tokens) + logprobs.unsqueeze(-1)  # [1, bs, nb, nt]
            # mask logprobs if they are finished or it's the first iteration
            new_logprobs = self.mask_logprobs(bs, finished, iteration, logprobs, new_logprobs, num_beams, num_tokens)
            # TODO take into account sequence_length for getting the top logprobs and their indices
            logprobs, beams = torch.topk(new_logprobs, k=num_beams, dim=-1)  # [1, bs, num_beams]
            # decompose flat index: parent beam id and token id.
            # NOTE(review): relies on legacy integer division of Long tensors;
            # on newer PyTorch `/` is true division — should be `//`. Verify torch version.
            parents = beams / num_tokens
            step_inputs = beams % num_tokens
            parent_indices = reshape_parent_indices(parents.view(-1), bs=bs, num_beams=num_beams)
            # reorder bookkeeping and prefixes to follow the surviving parents
            finished = torch.index_select(finished, 0, parent_indices.data)
            step_inputs = step_inputs.view(1, -1).contiguous()
            new_finished = (step_inputs.data == self.eos_token).view(-1)
            inputs = torch.index_select(inputs, dim=1, index=parent_indices)
            inputs = torch.cat([inputs, step_inputs], dim=0)
            finished = finished | new_finished
            iteration += 1
            self.beam_outputs = torch.index_select(self.beam_outputs, dim=1, index=parent_indices.cpu())
            self.beam_outputs = torch.cat([self.beam_outputs, step_inputs.cpu()], dim=0)
        # ensure the outputs is the output of the last layer [sl, bs, nt]
        outputs = torch.cat(final_outputs, dim=0)
        self.beam_outputs = self.beam_outputs.view(-1, bs, num_beams)
        return outputs