#!/usr/bin/env python3
import math
from typing import List, Tuple
import torch
from fairseq import search, utils
from fairseq.models import FairseqIncrementalDecoder
from pytorch_translate import utils as pytorch_translate_utils
from torch import Tensor
class SequenceGenerator(object):
def __init__(
self,
models,
tgt_dict,
beam_size=1,
minlen=1,
maxlen=None,
stop_early=True,
normalize_scores=True,
len_penalty=0,
unk_reward=0,
lexicon_reward=0,
retain_dropout=False,
word_reward=0,
model_weights=None,
use_char_source=False,
diverse_beam_groups=-1,
diverse_beam_strength=0.5,
diversity_sibling_gamma=0.0,
sampling=False,
sampling_topk=-1,
temperature=1,
):
"""Generates translations of a given source sentence.
Args:
models: List of FairseqEncoderDecoderModel objects. Each one must
implement reorder_encoder_output() method to replicate encoder
outputs.
min/maxlen: The length of the generated output will be bounded by
minlen and maxlen (not including the end-of-sentence marker).
stop_early: Stop generation immediately after we finalize beam_size
hypotheses, even though longer hypotheses might have better
normalized scores.
normalize_scores: Normalize scores by the length of the output.
word_reward: add this value to score each token except EOS
(an alternative method to len_penalty for encouraging longer
output)
model_weights: None or list of Python floats of the same length as
`models` with ensemble interpolation weights.
use_char_source: if True, encoder inputs consist of (src_tokens,
src_lengths, char_inds, word_lengths)
diverse_beam_groups: number of groups for Diverse Beam Search
(-1 by default is vanilla beam search)
diverse_beam_strength: strength of diversity penalty for Diverse
Beam Search.
diversity_sibling_gamma: diversity rate for the sibling-rank penalty
(0.0 by default, which disables the penalty)
sampling (bool, optional): sample outputs instead of beam search
(default: False)
sampling_topk (int, optional): only sample among the top-k choices
at each step (default: -1)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
"""
self.models = models
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
self.minlen = minlen
max_decoder_len = min(m.max_decoder_positions() for m in self.models)
self.maxlen = (
max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
)
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_reward = unk_reward
self.lexicon_reward = lexicon_reward
self.lexicon_indices = tgt_dict.lexicon_indices_list()
self.retain_dropout = retain_dropout
self.temperature = temperature
self.word_reward = word_reward
if model_weights is not None:
assert len(models) == len(model_weights)
self.model_weights = model_weights
else:
self.model_weights = [1.0 / len(models)] * len(models)
self.use_char_source = use_char_source
assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
assert temperature > 0, "--temperature must be greater than 0"
if sampling:
self.search = search.Sampling(tgt_dict, sampling_topk)
elif diverse_beam_groups > 0:
self.search = search.DiverseBeamSearch(
tgt_dict, diverse_beam_groups, diverse_beam_strength
)
else:
self.search = search.BeamSearch(tgt_dict)
self.diversity_sibling_gamma = diversity_sibling_gamma
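# Hedged usage sketch (added for clarity; not part of the original module).
# It assumes `models` is an ensemble of FairseqEncoderDecoderModel objects,
# `tgt_dict` is a pytorch_translate target dictionary (it must provide
# lexicon_indices_list()), and `data_itr` is a batched dataset iterator; all
# three names are placeholders supplied by the surrounding training/eval code.
#
#   generator = SequenceGenerator(
#       models, tgt_dict, beam_size=5, word_reward=0.25
#   ).cuda()
#   for sample_id, src_tokens, ref_tokens, hypos in generator.generate_batched_itr(
#       data_itr, cuda=True
#   ):
#       best = hypos[0]  # hypotheses are sorted by score, best first
#       print(sample_id, best["score"], best["tokens"])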
def cuda(self):
for model in self.models:
model.cuda()
return self
def generate_batched_itr(
self,
data_itr,
beam_size=None,
maxlen_a=0.0,
maxlen_b=None,
cuda=False,
timer=None,
prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
where x is the source sentence length.
cuda: use GPU for generation
timer: StopwatchMeter for timing generations.
"""
if maxlen_b is None:
maxlen_b = self.maxlen
for sample in data_itr:
if "net_input" not in sample:
continue
if cuda:
s = utils.move_to_cuda(sample)
else:
s = sample
input = s["net_input"]
srclen = input["src_tokens"].size(1)
if self.use_char_source:
encoder_input = {
k: v
for k, v in input.items()
if k in ["src_tokens", "src_lengths", "char_inds", "word_lengths"]
}
else:
encoder_input = {
k: v for k, v in input.items() if k in ["src_tokens", "src_lengths"]
}
if timer is not None:
timer.start()
with torch.no_grad():
hypos = self.generate(
encoder_input=encoder_input,
beam_size=beam_size,
maxlen=int(maxlen_a * srclen + maxlen_b),
prefix_tokens=s["target"][:, :prefix_size]
if prefix_size > 0
else None,
)
if timer is not None:
timer.stop(s["ntokens"])
for i, id in enumerate(s["id"]):
# remove padding
src = utils.strip_pad(input["src_tokens"][i, :], self.pad)
ref = utils.strip_pad(s["target"][i, :], self.pad)
yield id, src, ref, hypos[i]
@torch.no_grad()
def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
encoder_inputs = self.prepare_encoder_inputs(encoder_input)
encoder_outs, incremental_states = self._encode(encoder_input=encoder_inputs)
return self._decode_target(
encoder_input,
encoder_outs,
incremental_states,
self.diversity_sibling_gamma,
beam_size,
maxlen,
prefix_tokens,
)
def prepare_encoder_inputs(self, encoder_input):
if self.use_char_source:
encoder_inputs = (
encoder_input["src_tokens"],
encoder_input["src_lengths"],
encoder_input["char_inds"],
encoder_input["word_lengths"],
)
else:
encoder_inputs = (encoder_input["src_tokens"], encoder_input["src_lengths"])
return encoder_inputs
def _build_constraints(self, src_tokens, beam_size):
"""
Stub function for adding application-specific constraint checks on
the candidates being generated during beam search. This and the stub
functions below can be overridden in a child class without needing to
touch the actual beam search code.
"""
pass
def _apply_constraint_penalty(self, scores):
pass
def _update_constraints(self, constraints, next_tokens, idx):
pass
def _reorder_constraints(self, constraints, new_indices):
pass
def _apply_eos_constraints(self, constraints, eos_bbsz_idx, eos_scores):
pass
def _finalize_constrained_results(self, finalized, device):
pass
def _decode_target(
self,
encoder_input,
encoder_outs,
incremental_states,
diversity_sibling_gamma=0.0,
beam_size=None,
maxlen=None,
prefix_tokens=None,
):
src_tokens_tensor = pytorch_translate_utils.get_source_tokens_tensor(
encoder_input["src_tokens"]
)
beam_size = beam_size if beam_size is not None else self.beam_size
bsz = src_tokens_tensor.size(0)
reorder_indices = (
torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1).long()
)
for i, model in enumerate(self.models):
encoder_outs[i] = model.encoder.reorder_encoder_out(
encoder_out=encoder_outs[i],
new_order=reorder_indices.type_as(src_tokens_tensor),
)
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
# initialize buffers
scores = src_tokens_tensor.new(bsz * beam_size, maxlen + 1).float().fill_(0)
scores_buf = scores.clone()
tokens = src_tokens_tensor.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos
# source encoding length (may differ from the source input length)
if isinstance(encoder_outs[0], (list, tuple)):
src_encoding_len = encoder_outs[0][0].size(0)
elif isinstance(encoder_outs[0], dict):
if isinstance(encoder_outs[0]["encoder_out"], tuple):
# Fairseq compatibility
src_encoding_len = encoder_outs[0]["encoder_out"][0].size(1)
else:
src_encoding_len = encoder_outs[0]["encoder_out"].size(0)
attn = scores.new(bsz * beam_size, src_encoding_len, maxlen + 2)
attn_buf = attn.clone()
# list of completed sentences
finalized = [[] for i in range(bsz)]
finished = [False for i in range(bsz)]
worst_finalized = [{"idx": None, "score": -math.inf} for i in range(bsz)]
num_remaining_sent = bsz
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
# helper function for allocating buffers on the fly
buffers = {}
# init constraints
constraints = self._build_constraints(src_tokens_tensor, beam_size)
def buffer(name, type_of=tokens): # noqa
if name not in buffers:
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent, step, unfinalized_scores=None):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
possible score among unfinalized hypotheses.
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
if self.stop_early or step == maxlen or unfinalized_scores is None:
return True
# stop if the best unfinalized score is worse than the worst
# finalized one
best_unfinalized_score = unfinalized_scores[sent].max()
if self.normalize_scores:
best_unfinalized_score /= (maxlen + 1) ** self.len_penalty
if worst_finalized[sent]["score"] >= best_unfinalized_score:
return True
return False
def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
indicating which hypotheses to finalize
eos_scores: A vector of the same size as bbsz_idx containing
scores for each hypothesis
unfinalized_scores: A vector containing scores for all
unfinalized hypotheses
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[
:, 1 : step + 2
] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
# normalize sentence-level scores
if self.normalize_scores:
eos_scores /= (step + 1) ** self.len_penalty
sents_seen = set()
for i, (idx, score) in enumerate(
zip(bbsz_idx.tolist(), eos_scores.tolist())
):
sent = idx // beam_size
sents_seen.add(sent)
def get_hypo():
_, alignment = attn_clone[i].max(dim=0)
return {
"tokens": tokens_clone[i],
"score": score,
"attention": attn_clone[i], # src_len x tgt_len
"alignment": alignment,
"positional_scores": pos_scores[i],
}
if len(finalized[sent]) < beam_size:
finalized[sent].append(get_hypo())
elif not self.stop_early and score > worst_finalized[sent]["score"]:
# replace worst hypo for this sentence with new/better one
worst_idx = worst_finalized[sent]["idx"]
if worst_idx is not None:
finalized[sent][worst_idx] = get_hypo()
# find new worst finalized hypo for this sentence
idx, s = min(
enumerate(finalized[sent]), key=lambda r: r[1]["score"]
)
worst_finalized[sent] = {"score": s["score"], "idx": idx}
# return number of hypotheses finished this step
num_finished = 0
for sent in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfinalized_scores):
finished[sent] = True
num_finished += 1
return num_finished
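# Hedged numeric note (added): with normalize_scores enabled, a hypothesis
# finalized at `step` with cumulative log-prob S is ranked by
# S / (step + 1) ** len_penalty; e.g. S = -6.0 finalized at step 2 with
# len_penalty = 1 is ranked as -2.0. With the default len_penalty = 0 the
# divisor is 1, so no length normalization is applied.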
reorder_state = None
possible_translation_tokens = None
for step in range(maxlen + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
for model in self.models:
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.reorder_incremental_state(
incremental_states[model], reorder_state
)
# Run decoder for one step
logprobs, avg_attn, possible_translation_tokens = self._decode(
tokens[:, : step + 1],
encoder_outs,
incremental_states,
possible_translation_tokens,
)
logprobs[:, self.pad] = -math.inf # never select pad
# apply unk reward
if possible_translation_tokens is None:
# No vocab reduction, so unk is represented by self.unk at
# position self.unk
unk_index = self.unk
logprobs[:, unk_index] += self.unk_reward
else:
# When we use vocab reduction, the token value self.unk may not
# be at the position self.unk, but somewhere else in the list
# of possible_translation_tokens. It may also be absent from
# possible_translation_tokens entirely, in which case we cannot
# generate an unk.
unk_pos = torch.nonzero(possible_translation_tokens == self.unk)
if unk_pos.size()[0] != 0:
# only add unk_reward if unk index appears in
# possible_translation_tokens
unk_index = unk_pos[0][0]
logprobs[:, unk_index] += self.unk_reward
# external lexicon reward
logprobs[:, self.lexicon_indices] += self.lexicon_reward
logprobs += self.word_reward
logprobs[:, self.eos] -= self.word_reward
# Record attention scores
if avg_attn is not None:
attn[:, :, step + 1].copy_(avg_attn)
cand_scores = buffer("cand_scores", type_of=scores)
cand_indices = buffer("cand_indices")
cand_beams = buffer("cand_beams")
eos_bbsz_idx = buffer("eos_bbsz_idx")
eos_scores = buffer("eos_scores", type_of=scores)
scores = scores.type_as(logprobs)
scores_buf = scores_buf.type_as(logprobs)
if step < maxlen:
self._apply_constraint_penalty(scores) # stub call
if prefix_tokens is not None and step < prefix_tokens.size(1):
logprobs_slice = logprobs.view(bsz, -1, logprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
logprobs_slice, dim=1, index=prefix_tokens[:, step].view(-1, 1)
).expand(-1, cand_size)
cand_indices = (
prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
)
cand_beams.resize_as_(cand_indices).fill_(0)
else:
possible_tokens_size = self.vocab_size
if possible_translation_tokens is not None:
possible_tokens_size = possible_translation_tokens.size(0)
if diversity_sibling_gamma > 0:
logprobs = self.diversity_sibling_rank(
logprobs.view(bsz, -1, possible_tokens_size),
diversity_sibling_gamma,
)
cand_scores, cand_indices, cand_beams = self.search.step(
step,
logprobs.view(bsz, -1, possible_tokens_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
)
# vocabulary reduction
if possible_translation_tokens is not None:
possible_translation_tokens = possible_translation_tokens.view(
1, possible_tokens_size
).expand(cand_indices.size(0), possible_tokens_size)
cand_indices = torch.gather(
possible_translation_tokens,
dim=1,
index=cand_indices,
out=cand_indices,
)
else:
# finalize all active hypotheses once we hit maxlen
# pick the hypothesis with the highest log prob of EOS right now
logprobs.add_(scores[:, step - 1].view(-1, 1))
torch.sort(
logprobs[:, self.eos],
descending=True,
out=(eos_scores, eos_bbsz_idx),
)
num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)
assert num_remaining_sent == 0
break
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add_(bbsz_offsets)
# finalize hypotheses that end in eos
eos_mask = cand_indices.eq(self.eos)
if step >= self.minlen:
# only consider eos when it's among the top beam_size indices
torch.masked_select(
cand_bbsz_idx[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_bbsz_idx,
)
if eos_bbsz_idx.numel() > 0:
torch.masked_select(
cand_scores[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_scores,
)
self._apply_eos_constraints(constraints, eos_bbsz_idx, eos_scores)
num_remaining_sent -= finalize_hypos(
step, eos_bbsz_idx, eos_scores, cand_scores
)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
assert step < maxlen
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
active_mask = buffer("active_mask")
torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[: eos_mask.size(1)],
out=active_mask,
)
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos, _ignore = buffer("active_hypos"), buffer("_ignore")
torch.topk(
active_mask,
k=beam_size,
dim=1,
largest=False,
out=(_ignore, active_hypos),
)
active_bbsz_idx = buffer("active_bbsz_idx")
torch.gather(cand_bbsz_idx, dim=1, index=active_hypos, out=active_bbsz_idx)
active_scores = torch.gather(
cand_scores,
dim=1,
index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
torch.index_select(
tokens[:, : step + 1],
dim=0,
index=active_bbsz_idx,
out=tokens_buf[:, : step + 1],
)
torch.gather(
cand_indices,
dim=1,
index=active_hypos,
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
# update constraints for next step
constraints = self._reorder_constraints(constraints, active_bbsz_idx)
self._update_constraints(constraints, tokens_buf[:, step + 1], step)
if step > 0:
torch.index_select(
scores[:, :step],
dim=0,
index=active_bbsz_idx,
out=scores_buf[:, :step],
)
torch.gather(
cand_scores,
dim=1,
index=active_hypos,
out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
)
# copy attention for active hypotheses
torch.index_select(
attn[:, :, : step + 2],
dim=0,
index=active_bbsz_idx,
out=attn_buf[:, :, : step + 2],
)
# swap buffers
tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores
attn, attn_buf = attn_buf, attn
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(bsz):
finalized[sent] = sorted(
finalized[sent], key=lambda r: r["score"], reverse=True
)
self._finalize_constrained_results(finalized, scores.device)
return finalized
def _encode(self, encoder_input):
encoder_outs = []
incremental_states = {}
for model in self.models:
if not self.retain_dropout:
model.eval()
if isinstance(model.decoder, FairseqIncrementalDecoder):
incremental_states[model] = {}
else:
incremental_states[model] = None
encoder_out = model.encoder(*encoder_input)
encoder_outs.append(encoder_out)
return encoder_outs, incremental_states
@staticmethod
def gather_probs(all_translation_tokens, all_probs):
"""
Maps probabilities from multiple models with different output softmax
dimensions onto the same combined token space. The example below is
simplified; in practice the probs are in log space and have size
[bsz, len(possible_translation_tokens)].
Model 1:
possible_translation_tokens: [3, 7, 8, 9]
probs: [0.25, 0.25, 0.25, 0.25]
Model 2:
possible_translation_tokens: [0, 3, 5]
probs: [0.4, 0.5, 0.1]
all_translation_tokens: [[3, 7, 8, 9], [0, 3, 5]]
all_probs: [[0.25, 0.25, 0.25, 0.25], [0.4, 0.5, 0.1]]
possible_translation_tokens = [0, 3, 5, 7, 8, 9] (order varies)
mapped_probs for model 1: [0 , 0.25, 0 , 0.25, 0.25, 0.25]
mapped_probs for model 2: [0.4, 0.5 , 0.1, 0 , 0 , 0]
avg_probs = [0.4, 0.75, 0.1, 0.25, 0.25, 0.25] (order varies but
corresponds to possible_translation_tokens)
Inputs:
all_translation_tokens: List[Optional[possible_translation_tokens]]
where possible_translation_tokens is a flat Tensor representing
the possible translation tokens from model output. Note that the
possible_translation_tokens will be None only if vocab reduction
was not used.
all_probs: List[probs] where probs is a flat Tensor of normalized
probs for each model output. If vocab reduction was not used,
each probs list will be of length vocab size. Otherwise, each
probs will be the same length as that model's
possible_translation_tokens
Returns:
avg_probs: average probabilities of tokens from a merged list of
possible_translation_tokens from every model.
possible_translation_tokens: merged list of
possible_translation_tokens from every model.
"""
assert len(all_translation_tokens) == len(all_probs), (
f"Number of possible_translation_tokens tensors in "
f"all_translation_tokens list -- got length "
f"{len(all_translation_tokens)} -- should match the number of "
f"probs tensors in all_probs list -- got length {len(all_probs)}.\n"
f"all_translation_tokens: {all_translation_tokens}\n"
f"all_probs: {all_probs}"
)
possible_translation_tokens = None
inv_indices_per_model = [None] * len(all_translation_tokens)
if all_translation_tokens[0] is not None:
# Get unique translation tokens out of all the
# possible_translation_tokens for every model.
# inverse indices for the example above: [5, 4, 2, 1, 3, 5, 0]
possible_translation_tokens, inverse_indices = torch.unique(
torch.cat(all_translation_tokens, dim=0),
sorted=False,
return_inverse=True,
)
# softmax_sizes for the example above: [4, 3]
softmax_sizes = [
translation_tokens.size(0)
for translation_tokens in all_translation_tokens
]
inv_indices_per_model = torch.split(
inverse_indices, split_size_or_sections=softmax_sizes
)
avg_probs = None
for inv_ind, probs in zip(inv_indices_per_model, all_probs):
mapped_probs = probs
if possible_translation_tokens is not None:
# The corresponding model did not use vocab reduction if
# possible_translation_tokens is None.
mapped_probs = torch.zeros(
(probs.size(0), possible_translation_tokens.size(0)),
device=probs.device,
)
mapped_probs[:, inv_ind] = probs
if avg_probs is None:
avg_probs = mapped_probs
else:
avg_probs.add_(mapped_probs)
return avg_probs, possible_translation_tokens
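# Hedged illustration (added for clarity, not in the original file): the
# docstring's two-model example can be reproduced with tiny tensors, assuming
# the module-level `torch` import. The values below are invented for the example.
#
#   probs_1 = torch.tensor([[0.25, 0.25, 0.25, 0.25]])
#   tokens_1 = torch.tensor([3, 7, 8, 9])
#   probs_2 = torch.tensor([[0.4, 0.5, 0.1]])
#   tokens_2 = torch.tensor([0, 3, 5])
#   summed, merged = SequenceGenerator.gather_probs(
#       all_translation_tokens=[tokens_1, tokens_2],
#       all_probs=[probs_1, probs_2],
#   )
#   # `merged` holds the union {0, 3, 5, 7, 8, 9} (in some order) and
#   # `summed` holds the per-token sums aligned to `merged`, e.g. the entry
#   # for token 3 is 0.25 + 0.5 = 0.75.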
def _decode(
self, tokens, encoder_outs, incremental_states, possible_translation_tokens=None
):
avg_attn = None
all_translation_tokens = []
all_log_probs = []
for model_weight, model, encoder_out in zip(
self.model_weights, self.models, encoder_outs
):
with torch.no_grad():
if (
possible_translation_tokens is not None
and len(possible_translation_tokens.shape) > 1
):
# reverse beam replication
possible_translation_tokens = possible_translation_tokens[0]
decoder_out = list(
model.decoder(
tokens,
encoder_out,
incremental_states[model],
possible_translation_tokens=possible_translation_tokens,
)
)
decoder_out[0] = decoder_out[0][:, -1, :]
if self.temperature != 1.0:
decoder_out[0].div_(self.temperature)
attn = decoder_out[1]
if len(decoder_out) == 3:
possible_translation_tokens = decoder_out[2]
else:
possible_translation_tokens = None
if (
hasattr(model.decoder, "adaptive_softmax")
and model.decoder.adaptive_softmax is not None
):
decoder_out[0] = decoder_out[0].unsqueeze(1)
# to use get_normalized_probs in adaptive softmax decoder
# the sample object is needed. During inference, the target
# should be set to None
log_probs = model.get_normalized_probs(
decoder_out, log_probs=True, sample={"target": None}
)
log_probs = model_weight * log_probs[:, -1, :]
else:
log_probs = model.get_normalized_probs(decoder_out, log_probs=True)
log_probs = model_weight * log_probs
all_translation_tokens.append(possible_translation_tokens)
all_log_probs.append(log_probs)
if attn is not None:
attn = attn[:, -1, :].data
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
avg_log_probs, possible_translation_tokens = SequenceGenerator.gather_probs(
all_translation_tokens=all_translation_tokens, all_probs=all_log_probs
)
if avg_attn is not None:
avg_attn.div_(len(self.models))
return avg_log_probs, avg_attn, possible_translation_tokens
def diversity_sibling_rank(self, logprobs, gamma):
"""
See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation"
for details
"""
_, beam_size, vocab_size = logprobs.size()
logprobs = logprobs.view(-1, vocab_size)
# Keep consistent with beamsearch class in fairseq
k = min(2 * beam_size, vocab_size)
_, indices = torch.topk(logprobs, k)
# Set diverse penalty as k for all words
diverse_penalty = torch.ones_like(logprobs) * k
diversity_sibling_rank = (
torch.arange(0, k).view(-1, 1).expand(k, logprobs.size(0)).type_as(logprobs)
)
# Set diversity penalty accordingly for top-k words
diverse_penalty[
torch.arange(0, logprobs.size(0)).long(), indices.transpose(0, 1)
] = diversity_sibling_rank
logprobs -= gamma * diverse_penalty
return logprobs
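# Hedged worked example (added; not from the paper or the original file): the
# sibling-rank penalty above subtracts gamma * rank from the top-k candidates
# within each beam, where rank 0 is the best candidate. With gamma = 0.5 and a
# beam whose top three log-probs are [-1.0, -1.2, -1.5], the penalized values
# become [-1.0, -1.7, -2.5], so lower-ranked siblings of the same beam are
# pushed down and other beams get more room to contribute diverse candidates.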
class BeamDecode(torch.jit.ScriptModule):
"""
Decodes the output of Beam Search to get the top hypotheses
"""
def __init__(self, eos_token_id, length_penalty, nbest, beam_size, stop_at_eos):
super().__init__()
self.eos_token_id = torch.jit.Attribute(eos_token_id, int)
self.length_penalty = torch.jit.Attribute(length_penalty, float)
self.nbest = torch.jit.Attribute(nbest, int)
self.beam_size = torch.jit.Attribute(beam_size, int)
self.stop_at_eos = torch.jit.Attribute(int(stop_at_eos), int)
@torch.jit.script_method
@torch.no_grad()
def forward(
self,
beam_tokens: Tensor,
beam_scores: Tensor,
token_weights: Tensor,
beam_prev_indices: Tensor,
num_steps: int,
) -> List[Tuple[Tensor, float, List[float], Tensor, Tensor]]:
self._check_dimensions(
beam_tokens, beam_scores, token_weights, beam_prev_indices, num_steps
)
end_states = self._get_all_end_states(
beam_tokens, beam_scores, beam_prev_indices, num_steps
)
# outputs is a list containing one tuple per hypothesis:
# Tuple[hypothesis tokens, hypothesis score, token-level scores, attention weights, best indices]
outputs = torch.jit.annotate(
List[Tuple[Tensor, float, List[float], Tensor, Tensor]], []
)
for state_idx in range(len(end_states)):
state = end_states[state_idx]
hypothesis_score = float(state[0])
beam_indices = self._get_output_steps_to_beam_indices(
state, beam_prev_indices
)
beam_output = torch.jit.annotate(List[Tensor], [])
token_level_scores = torch.jit.annotate(List[float], [])
position = int(state[1])
hyp_index = int(state[2])
# best_indices represents the ending position of one hypothesis:
# the first index corresponds to the step number, the second to the beam index
best_indices = torch.tensor([position, hyp_index])
back_alignment_weights = []
assert position + 1 == len(beam_indices)
pos = 1
prev_beam_index = -1
while pos < len(beam_indices):
beam_index = beam_indices[pos]
beam_output.append(beam_tokens[pos][beam_index])
if pos == 1:
# beam_scores[0][:] are all 0s
token_level_scores.append(float(beam_scores[pos][beam_index]))
else:
token_level_scores.append(
float(beam_scores[pos][beam_index])
- float(beam_scores[pos - 1][prev_beam_index])
)
back_alignment_weights.append(token_weights[pos][beam_index].detach())
prev_beam_index = beam_index
pos += 1
outputs.append(
(
torch.stack(beam_output),
hypothesis_score,
token_level_scores,
torch.stack(back_alignment_weights, dim=1),
best_indices,
)
)
return outputs
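# Hedged consumption sketch (added): `decoder` below is assumed to be a
# constructed BeamDecode instance, and the four input tensors are assumed to
# come from a prior beam search pass, indexed as [position][beam_index].
#
#   for tokens, score, token_scores, attn_weights, best_idx in decoder(
#       beam_tokens, beam_scores, token_weights, beam_prev_indices, num_steps
#   ):
#       # tokens: 1-D tensor of output token ids for this hypothesis
#       # attn_weights: src_len x tgt_len alignment weights
#       ...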
@torch.jit.script_method
def _get_output_steps_to_beam_indices(
self, end_state: Tensor, beam_prev_indices: Tensor
) -> List[int]:
"""
Returns a mapping from each output position to the beam index that was
picked from the beam search results.
"""
present_position = int(end_state[1])
beam_index = int(end_state[2])
beam_indices = torch.jit.annotate(List[int], [])
while present_position >= 0:
beam_indices.insert(0, beam_index)
beam_index = int(beam_prev_indices[present_position][beam_index])
present_position = present_position - 1
return beam_indices
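# Hedged backtracking example (added for clarity): suppose beam_prev_indices is
#   [[0, 0],
#    [0, 0],
#    [1, 0],
#    [0, 1]]
# and end_state points at position 3, beam 1. Walking backwards gives beam 1 at
# step 3, beam_prev_indices[3][1] = 1 at step 2, beam_prev_indices[2][1] = 0 at
# step 1, and beam_prev_indices[1][0] = 0 at step 0, so the returned list is
# [0, 0, 1, 1].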
@torch.jit.script_method
def _add_to_end_states(
self, end_states: List[Tensor], min_score: float, state: Tensor, min_index: int
) -> Tuple[List[Tensor], float, int]:
"""
Maintains a list of at most `nbest` highest-scoring end states.
"""
if len(end_states) < self.nbest:
end_states.append(state)
# keep min_score and min_index updated
if float(state[0]) <= min_score:
min_score = float(state[0])
min_index = len(end_states) - 1
elif bool(state[0] > min_score):
# replace worst hypo with the new one
end_states[min_index] = state
# find new worst hypo, keep min_score and min_index updated
min_index = -1
# not using float("inf") temporarily bc of TorchScript bug
# using max representable value in fp16
min_score = 65504.0
for idx in range(len(end_states)):
s = end_states[idx]
if bool(float(s[0]) <= min_score):
min_index = idx
min_score = float(s[0])
return end_states, min_score, min_index
@torch.jit.script_method
def _get_all_end_states(
self,
beam_tokens: Tensor,
beam_scores: Tensor,
beam_prev_indices: Tensor,
num_steps: int,
) -> Tensor:
"""
Return all end states and hypothesis scores for those end states.
"""
# not using float("inf") temporarily bc of TorchScript bug
# using max representable value in fp16
min_score = 65504.0
min_index = -1
end_states = torch.jit.annotate(List[Tensor], [])
prev_hypo_is_finished = torch.zeros(self.beam_size).byte()
position = 1
while bool(position <= num_steps):
hypo_is_finished = torch.zeros(self.beam_size).byte()
for hyp_index in range(self.beam_size):
prev_pos = beam_prev_indices[position][hyp_index]
hypo_is_finished[hyp_index] = prev_hypo_is_finished[prev_pos]
# If the hypothesis was already completed at a previous position,
# just carry that status forward and skip it here
if bool(hypo_is_finished[hyp_index] == 0):
# If the present token is EOS or we have reached max_length
# then hypothesis is complete
if bool(
beam_tokens[position][hyp_index] == self.eos_token_id
) or bool(position == num_steps):
if bool(self.stop_at_eos):
hypo_is_finished[hyp_index] = 1
hypo_score = float(beam_scores[position][hyp_index])
if bool(self.length_penalty != 0):
hypo_score = hypo_score / float(position) ** float(
self.length_penalty
)
end_states, min_score, min_index = self._add_to_end_states(
end_states,
min_score,
torch.tensor(
[hypo_score, float(position), float(hyp_index)]
),
min_index,
)
prev_hypo_is_finished = hypo_is_finished
position = position + 1
end_states = torch.stack(end_states)
_, sorted_end_state_indices = end_states[:, 0].sort(dim=0, descending=True)
end_states = end_states[sorted_end_state_indices, :]
return end_states
@torch.jit.script_method
def _check_dimensions(
self,
beam_tokens: Tensor,
beam_scores: Tensor,
token_weights: Tensor,
beam_prev_indices: Tensor,
num_steps: int,
) -> None:
assert (
beam_tokens.size(1) == self.beam_size
), "Dimension of beam_tokens : {} and beam size : {} are not consistent".format(