This repository has been archived by the owner on Jul 7, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
transformer.py
2190 lines (1866 loc) · 77.1 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer model from "Attention Is All You Need".
The Transformer model consists of an encoder and a decoder. Both are stacks
of self-attention layers followed by feed-forward layers. This model yields
good results on a number of problems, especially in NLP and machine translation.
See "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) for the full
description of the model and the results obtained with its early version.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.data_generators import librispeech
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import expert_utils
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow as tf
from tensorflow.python.util import nest
@registry.register_model
class Transformer(t2t_model.T2TModel):
  """Attention net.  See file docstring.

  In addition to the training graph built by `body`, this class provides
  cache-based incremental decoding (`_fast_decode` and `_fast_decode_tpu`),
  which `_greedy_infer` and `_beam_decode` dispatch to.
  """

  def __init__(self, *args, **kwargs):
    super(Transformer, self).__init__(*args, **kwargs)
    self.attention_weights = dict()  # For visualizing attention heads.

  def encode(self, inputs, target_space, hparams, features=None, losses=None):
    """Encode transformer inputs.

    Args:
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      losses: optional list onto which to append extra training losses

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
    """
    inputs = common_layers.flatten4d3d(inputs)

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer_prepare_encoder(
            inputs, target_space, hparams, features=features))

    # tf.nn.dropout takes a keep probability, hence the `1.0 - rate`.
    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    encoder_output = transformer_encoder(
        encoder_input,
        self_attention_bias,
        hparams,
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=self.attention_weights,
        losses=losses)

    return encoder_output, encoder_decoder_attention_bias

  def decode(self,
             decoder_input,
             encoder_output,
             encoder_decoder_attention_bias,
             decoder_self_attention_bias,
             hparams,
             cache=None,
             decode_loop_step=None,
             nonpadding=None,
             losses=None):
    """Decode Transformer outputs from encoder representation.

    Args:
      decoder_input: inputs to bottom of the model.
          [batch_size, decoder_length, hidden_dim]
      encoder_output: Encoder representation.
          [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
          encoder-decoder attention. [batch_size, input_length]
      decoder_self_attention_bias: Bias and mask weights for decoder
          self-attention. [batch_size, decoder_length]
      hparams: hyperparameters for model.
      cache: dict, containing tensors which are the results of previous
          attentions, used for fast decoding.
      decode_loop_step: An integer, step number of the decoding loop.
          Only used for inference on TPU.
      nonpadding: optional Tensor with shape [batch_size, decoder_length]
      losses: optional list onto which to append extra training losses

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
      When not training on TPU, an extra axis of size 1 is inserted at
      position 2 of the returned tensor.
    """
    # tf.nn.dropout takes a keep probability, hence the `1.0 - rate`.
    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    decoder_output = transformer_decoder(
        decoder_input,
        encoder_output,
        decoder_self_attention_bias,
        encoder_decoder_attention_bias,
        hparams,
        cache=cache,
        decode_loop_step=decode_loop_step,
        nonpadding=nonpadding,
        save_weights_to=self.attention_weights,
        losses=losses)

    if (common_layers.is_on_tpu() and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # TPU does not react kindly to extra dimensions.
      # TODO(noam): remove this once TPU is more forgiving of extra dims.
      return decoder_output
    else:
      # Expand since t2t expects 4d tensors.
      return tf.expand_dims(decoder_output, axis=2)

  def body(self, features):
    """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id": A scalar int from data_generators.problem.SpaceID.
          "expected_attentions": optional; when present an attention loss is
              computed against the saved attention weights.

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
      If an attention loss or extra losses were produced, a tuple of that
      tensor and a dict of losses ("attention_loss" or "extra_loss") is
      returned instead.
    """
    hparams = self._hparams

    # Filled in by encode/decode with any extra training losses.
    losses = []

    if self.has_input:
      inputs = features["inputs"]
      target_space = features["target_space_id"]
      encoder_output, encoder_decoder_attention_bias = self.encode(
          inputs, target_space, hparams, features=features, losses=losses)
    else:
      encoder_output, encoder_decoder_attention_bias = (None, None)

    targets = features["targets"]
    # Remember the incoming shape so the output can be reshaped to match it.
    targets_shape = common_layers.shape_list(targets)
    targets = common_layers.flatten4d3d(targets)

    decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
        targets, hparams, features=features)

    decoder_output = self.decode(
        decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        decoder_self_attention_bias,
        hparams,
        nonpadding=features_to_nonpadding(features, "targets"),
        losses=losses)

    expected_attentions = features.get("expected_attentions")
    if expected_attentions is not None:
      attention_loss = common_attention.encoder_decoder_attention_loss(
          expected_attentions, self.attention_weights,
          hparams.expected_attention_loss_type,
          hparams.expected_attention_loss_multiplier)
      # NOTE(review): this branch returns the decoder output without the
      # reshape to `targets_shape` and without the accumulated `losses` --
      # confirm that this is intended.
      return decoder_output, {"attention_loss": attention_loss}

    ret = tf.reshape(decoder_output, targets_shape)
    if losses:
      return ret, {"extra_loss": tf.add_n(losses)}
    else:
      return ret

  def _greedy_infer(self, features, decode_length, use_tpu=False):
    """Fast version of greedy decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      use_tpu: A bool. Whether to build the inference graph for TPU.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    # For real-valued modalities use the slow decode path for now.
    # NOTE(review): `use_tpu` is not forwarded to the slow path.
    if self._target_modality_is_real:
      return super(Transformer, self)._greedy_infer(features, decode_length)
    with tf.variable_scope(self.name):
      return (self._fast_decode_tpu(features, decode_length) if use_tpu else
              self._fast_decode(features, decode_length))

  def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
    """Beam search decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }
    """
    if self._hparams.self_attention_type != "dot_product":
      # Caching is not guaranteed to work with attention types other than
      # dot_product.
      # TODO(petershaw): Support fast decoding when using relative
      # position representations, i.e. "dot_product_relative" attention.
      return self._beam_decode_slow(features, decode_length, beam_size,
                                    top_beams, alpha)
    with tf.variable_scope(self.name):
      return self._fast_decode(features, decode_length, beam_size, top_beams,
                               alpha)

  def _fast_decode_tpu(self,
                       features,
                       decode_length,
                       beam_size=1):
    """Fast decoding.

    Implements only greedy decoding on TPU.

    Args:
      features: A map of string to model features.
      decode_length: An integer, how many additional timesteps to decode.
      beam_size: An integer, number of beams.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }.

    Raises:
      NotImplementedError: If there are multiple data shards, if the features
        come from a packed dataset, or if beam_size > 1.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    if "targets_segmentation" in features:
      raise NotImplementedError(
          "Decoding not supported on packed datasets "
          " If you want to decode from a dataset, use the non-packed version"
          " of the dataset when decoding.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.target_modality

    if self.has_input:
      inputs = features["inputs"]
      if target_modality.is_class_modality:
        decode_length = 1
      else:
        # Allow decoding up to `decode_length` steps beyond the input length.
        decode_length = (
            common_layers.shape_list(inputs)[1] + features.get(
                "decode_length", decode_length))

      # TODO(llion): Clean up this reshaping logic.
      inputs = tf.expand_dims(inputs, axis=1)
      if len(inputs.shape) < 5:
        inputs = tf.expand_dims(inputs, axis=4)
      s = common_layers.shape_list(inputs)
      batch_size = s[0]
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
      # _shard_features called to ensure that the variable names match
      inputs = self._shard_features({"inputs": inputs})["inputs"]
      input_modality = self._problem_hparams.input_modality["inputs"]
      with tf.variable_scope(input_modality.name):
        inputs = input_modality.bottom_sharded(inputs, dp)
      with tf.variable_scope("body"):
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode,
            inputs,
            features["target_space_id"],
            hparams,
            features=features)
      # Single shard only (checked above), so take the first result.
      encoder_output = encoder_output[0]
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
      partial_targets = None
    else:
      # The problem has no inputs.
      encoder_output = None
      encoder_decoder_attention_bias = None

      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs")
      if partial_targets is None:
        partial_targets = features["targets"]
      assert partial_targets is not None
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int64(partial_targets)
      partial_targets_shape = common_layers.shape_list(partial_targets)
      partial_targets_length = partial_targets_shape[1]
      decode_length = (
          partial_targets_length + features.get("decode_length", decode_length))
      batch_size = partial_targets_shape[0]

    if hparams.pos == "timing":
      positional_encoding = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
      positional_encoding = common_attention.add_positional_embedding(
          tf.zeros([1, decode_length + 1, hparams.hidden_size]),
          hparams.max_length, "body/targets_positional_embedding", None)
    else:
      positional_encoding = None

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: A tensor, inputs ids to the decoder. [batch_size, 1].
        i: An integer, Step number of the decoding loop.

      Returns:
        A tensor, processed targets [batch_size, 1, hidden_dim].
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if positional_encoding is not None:
        # tf.slice with explicit sizes is used here instead of Python slicing
        # to pick out the encoding for step `i` only.
        positional_encoding_shape = positional_encoding.shape.as_list()
        targets += tf.slice(
            positional_encoding, [0, i, 0],
            [positional_encoding_shape[0], 1, positional_encoding_shape[2]])
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_tpu_fn(ids, i, cache):
      """Go from ids to logits for next symbol on TPU.

      Args:
        ids: A tensor, symbol IDs.
        i: An integer, step number of the decoding loop. Only used for inference
            on TPU.
        cache: A dict, containing tensors which are the results of previous
            attentions, used for fast decoding.

      Returns:
        ret: A tensor, computed logits.
        cache: A dict, containing tensors which are the results of previous
            attentions, used for fast decoding.
      """
      # Only the most recent id is fed to the decoder.
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      # Take row `i` of the self-attention bias, keeping its full width.
      bias_shape = decoder_self_attention_bias.shape.as_list()
      bias = tf.slice(decoder_self_attention_bias, [0, 0, i, 0],
                      [bias_shape[0], bias_shape[1], 1, bias_shape[3]])

      # Overwrite every position after `i` with -1e9. The bias is transposed
      # so that the position axis comes first, which is the axis
      # alias_inplace_update indexes into.
      bias_padding = tf.fill([bias_shape[0], bias_shape[1], 1], -1e9)
      tmp_bias = tf.transpose(bias, perm=[3, 0, 1, 2])
      bias_index = i + 1
      while_condition = lambda bias_index, _: tf.less(bias_index, decode_length)

      def while_body(bias_index, tmp_bias):
        tmp_bias = common_layers.tf_inplace_ops().alias_inplace_update(
            tmp_bias, bias_index, bias_padding)
        return bias_index + 1, tmp_bias

      _, tmp_bias = tf.while_loop(
          while_condition, while_body, (bias_index, tmp_bias))
      bias = tf.transpose(tmp_bias, perm=[1, 2, 3, 0])

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode,
            targets,
            cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias,
            hparams,
            cache,
            i,
            nonpadding=features_to_nonpadding(features, "targets"))

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]

        def forced_logits():
          # tf.slice keeps the sliced axis, giving [batch_size, 1]. Squeeze
          # it so the ids are rank 1: tf.tile needs one multiple per input
          # dimension, and the one-hot result must have the same rank as
          # `ret` for the tf.cond below (this mirrors `partial_targets[:, i]`
          # in `_fast_decode`).
          forced_ids = tf.squeeze(
              tf.slice(partial_targets, [0, i],
                       [partial_targets.shape.as_list()[0], 1]),
              axis=1)
          return tf.one_hot(
              tf.tile(forced_ids, [beam_size]), vocab_size, 0.0, -1e9)

        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache

    ret = fast_decode_tpu(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_tpu_fn,
        hparams=hparams,
        decode_length=decode_length,
        beam_size=beam_size,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length)
    if partial_targets is not None:
      # Strip the forced prefix so only newly decoded ids are returned.
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    return ret

  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards, or if the
        features come from a packed dataset.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.target_modality
    if "targets_segmentation" in features:
      raise NotImplementedError(
          "Decoding not supported on packed datasets "
          " If you want to decode from a dataset, use the non-packed version"
          " of the dataset when decoding.")
    if self.has_input:
      inputs = features["inputs"]
      if target_modality.is_class_modality:
        decode_length = 1
      else:
        # Allow decoding up to `decode_length` steps beyond the input length.
        decode_length = (
            common_layers.shape_list(inputs)[1] + features.get(
                "decode_length", decode_length))

      # TODO(llion): Clean up this reshaping logic.
      inputs = tf.expand_dims(inputs, axis=1)
      if len(inputs.shape) < 5:
        inputs = tf.expand_dims(inputs, axis=4)
      s = common_layers.shape_list(inputs)
      batch_size = s[0]
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
      # _shard_features called to ensure that the variable names match
      inputs = self._shard_features({"inputs": inputs})["inputs"]
      input_modality = self._problem_hparams.input_modality["inputs"]
      with tf.variable_scope(input_modality.name):
        inputs = input_modality.bottom_sharded(inputs, dp)
      with tf.variable_scope("body"):
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode,
            inputs,
            features["target_space_id"],
            hparams,
            features=features)
      # Single shard only (checked above), so take the first result.
      encoder_output = encoder_output[0]
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
      partial_targets = None
    else:
      # The problem has no inputs.
      encoder_output = None
      encoder_decoder_attention_bias = None

      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs")
      if partial_targets is None:
        partial_targets = features["targets"]
      assert partial_targets is not None
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int64(partial_targets)
      partial_targets_shape = common_layers.shape_list(partial_targets)
      partial_targets_length = partial_targets_shape[1]
      decode_length = (
          partial_targets_length + features.get("decode_length", decode_length))
      batch_size = partial_targets_shape[0]

    if hparams.pos == "timing":
      positional_encoding = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
      positional_encoding = common_attention.add_positional_embedding(
          tf.zeros([1, decode_length, hparams.hidden_size]),
          hparams.max_length, "body/targets_positional_embedding", None)
    else:
      positional_encoding = None

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if positional_encoding is not None:
        targets += positional_encoding[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      # Only the most recent id is fed to the decoder.
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      # Row `i` of the bias, restricted to positions 0..i.
      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode,
            targets,
            cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias,
            hparams,
            cache,
            nonpadding=features_to_nonpadding(features, "targets"))

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]

        def forced_logits():
          return tf.one_hot(
              tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
              -1e9)

        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length)
    if partial_targets is not None:
      # Strip the forced prefix so only newly decoded ids are returned. The
      # time axis is axis 1 without a beam axis and axis 2 with one.
      if beam_size <= 1 or top_beams <= 1:
        ret["outputs"] = ret["outputs"][:, partial_targets_length:]
      else:
        ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
    return ret
def fast_decode_tpu(encoder_output,
                    encoder_decoder_attention_bias,
                    symbols_to_logits_fn,
                    hparams,
                    decode_length,
                    beam_size=1,
                    eos_id=beam_search.EOS_ID,
                    batch_size=None,
                    force_decode_length=False):
  """Runs greedy incremental decoding with fixed-size loop state (for TPU).

  A zero-filled per-layer cache of length `decode_length` is allocated up
  front, and the decoded ids are written in place into a
  [batch_size, decode_length] tensor, one column per step.

  Args:
    encoder_output: A tensor, output from encoder, or None if there is no
      encoder. When given, its first dimension overrides `batch_size`.
    encoder_decoder_attention_bias: A tensor, bias for use in encoder-decoder
      attention.
    symbols_to_logits_fn: Incremental decoding, function mapping triple
      `(ids, step, cache)` to a pair `(logits, cache)`.
    hparams: Run hyperparameters.
    decode_length: An integer, how many additional timesteps to decode.
    beam_size: An integer, number of beams. Must not exceed 1.
    eos_id: End-of-sequence symbol.
    batch_size: An integer, must be passed if there is no input.
    force_decode_length: A bool, whether to force the full decode length, or if
      False, stop when every sequence in the batch has produced eos_id.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, decode_length]
        "scores": `Tensor` of shape [batch_size] holding the summed log
            probabilities of the ids chosen at each step
    }.

  Raises:
    NotImplementedError: If beam size > 1.
  """
  has_encoder = encoder_output is not None
  if has_encoder:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_depth = hparams.attention_key_channels or hparams.hidden_size
  value_depth = hparams.attention_value_channels or hparams.hidden_size
  layer_names = [
      "layer_%d" % n
      for n in range(hparams.num_decoder_layers or hparams.num_hidden_layers)
  ]

  # Zero-initialised, full-length cache entries for every decoder layer.
  cache = {}
  for layer_name in layer_names:
    zero_k = common_attention.split_heads(
        tf.zeros([batch_size, decode_length, key_depth]), hparams.num_heads)
    zero_v = common_attention.split_heads(
        tf.zeros([batch_size, decode_length, value_depth]), hparams.num_heads)
    zero_f = tf.zeros([batch_size, decode_length, hparams.hidden_size])
    cache[layer_name] = {"k": zero_k, "v": zero_v, "f": zero_f}

  if has_encoder:
    # The encoder-side keys and values do not change between steps, so they
    # are computed once here and stored in the cache.
    for layer_name in layer_names:
      scope_name = (
          "body/decoder/%s/encdec_attention/multihead_attention" % layer_name)
      with tf.variable_scope(scope_name):
        enc_k = common_attention.split_heads(
            common_attention.compute_attention_component(
                encoder_output, key_depth, name="k"),
            hparams.num_heads)
        enc_v = common_attention.split_heads(
            common_attention.compute_attention_component(
                encoder_output, value_depth, name="v"),
            hparams.num_heads)
      layer_cache = cache[layer_name]
      layer_cache["k_encdec"] = enc_k
      layer_cache["v_encdec"] = enc_v

    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    raise NotImplementedError("Beam search inference on TPU is not supported")

  # Greedy
  def greedy_step(step, finished, prev_id, outputs, state, total_log_prob):
    """One step of greedy decoding."""
    logits, state = symbols_to_logits_fn(prev_id, step, state)
    log_probs = common_layers.log_prob_from_logits(logits)
    if hparams.sampling_method == "argmax":
      temperature = 0.0
    else:
      temperature = hparams.sampling_temp
    sampled = common_layers.sample_with_temperature(logits, temperature)
    finished = tf.logical_or(finished, tf.equal(sampled, eos_id))

    # Accumulate the log probability of the id picked for each batch entry.
    picked = tf.stack([tf.range(tf.to_int64(batch_size)), sampled], axis=1)
    total_log_prob = total_log_prob + tf.gather_nd(log_probs, picked)

    next_input = tf.expand_dims(sampled, axis=1)
    # Transpose so the time axis is first, write row `step` in place, then
    # transpose back.
    time_major = tf.transpose(outputs)
    time_major = common_layers.tf_inplace_ops().alias_inplace_update(
        time_major, step, tf.squeeze(next_input, axis=1))
    outputs = tf.transpose(time_major)
    return step + 1, finished, next_input, outputs, state, total_log_prob

  def keep_going(step, finished, *_):
    """Loop condition: True while decoding should continue."""
    done = tf.greater_equal(step, decode_length)
    if not force_decode_length:
      done = tf.logical_or(done, tf.reduce_all(finished))
    return tf.logical_not(done)

  initial_state = [
      tf.constant(0),
      tf.fill([batch_size], False),
      tf.zeros([batch_size, 1], dtype=tf.int64),
  ]
  # The output buffer is created before the other two id-related tensors are
  # consumed, in the same order as the loop variables below expect.
  initial_outputs = tf.zeros([batch_size, decode_length], dtype=tf.int64)
  initial_state.insert(1, None)
  initial_state.pop(1)
  initial_scores = tf.zeros([batch_size], dtype=tf.float32)

  # Every loop variable keeps a fully static shape across iterations.
  invariants = [
      tf.TensorShape([]),
      tf.TensorShape([batch_size]),
      tf.TensorShape([batch_size, 1]),
      tf.TensorShape([batch_size, decode_length]),
      nest.map_structure(
          lambda tensor: tf.TensorShape(tensor.shape.as_list()), cache),
      tf.TensorShape([batch_size]),
  ]

  final_state = tf.while_loop(
      keep_going,
      greedy_step,
      initial_state + [initial_outputs, cache, initial_scores],
      shape_invariants=invariants)
  decoded_ids = final_state[3]
  scores = final_state[5]

  return {"outputs": decoded_ids, "scores": scores}
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder, or None if there is no input.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.
    eos_id: End-of-sequence symbol id. Passed to beam search, and also used by
      the greedy loop to detect finished sequences.
    batch_size: an integer scalar - must be passed if `encoder_output` is
      None. When `encoder_output` is given, this argument is overwritten by
      the leading dimension of `encoder_output`.
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all sequences hit eos_id. Only consulted by the greedy
      loop in this function.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids.
            Beam search: shape [batch_size, <= decode_length] if
            top_beams == 1, or [batch_size, top_beams, <= decode_length]
            otherwise (the leading initial-id column is stripped).
            Greedy: int64, shape [batch_size, <= decode_length].
        "scores": Beam search: the scores returned by
            `beam_search.beam_search`, restricted to the returned beams.
            Greedy: float32 `Tensor` of shape [batch_size] holding the sum
            of the log probabilities of the ids chosen at every executed
            step.
    }

  Raises:
    ValueError: If both `encoder_output` and `batch_size` are None.
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]
  if batch_size is None:
    # Fail early with a clear message instead of letting `None` reach
    # tf.zeros() below, where the resulting error is hard to interpret.
    raise ValueError(
        "batch_size must be provided when encoder_output is None.")

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  vars_3d_num_heads = (
      hparams.num_heads if hparams.get("attention_variables_3d") else 0)

  # Per-layer decoding state. Every entry starts with length 0 along axis 1
  # (before split_heads is applied to "k" and "v").
  cache = {
      "layer_%d" % layer: {
          "k":
              common_attention.split_heads(
                  tf.zeros([batch_size, 0, key_channels]), hparams.num_heads),
          "v":
              common_attention.split_heads(
                  tf.zeros([batch_size, 0, value_channels]), hparams.num_heads),
          "f":
              tf.zeros([batch_size, 0, hparams.hidden_size]),
      } for layer in range(num_layers)
  }

  if encoder_output is not None:
    # Compute the encoder-decoder attention keys/values once per layer and
    # keep them in the cache (presumably so symbols_to_logits_fn can reuse
    # them on every step -- confirm against the attention implementation).
    for layer in range(num_layers):
      layer_name = "layer_%d" % layer
      with tf.variable_scope(
          "body/decoder/%s/encdec_attention/multihead_attention" % layer_name):
        k_encdec = common_attention.compute_attention_component(
            encoder_output, key_channels, name="k",
            vars_3d_num_heads=vars_3d_num_heads)
        k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
        v_encdec = common_attention.compute_attention_component(
            encoder_output, value_channels, name="v",
            vars_3d_num_heads=vars_3d_num_heads)
        v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
      cache[layer_name]["k_encdec"] = k_encdec
      cache[layer_name]["v_encdec"] = v_encdec

    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    # Drop the first position (the initial id) from every returned beam.
    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
      scores = scores[:, 0]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
      scores = scores[:, :top_beams]
  else:  # Greedy

    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      hit_eos |= tf.equal(next_id, eos_id)

      # Gather, for each batch row, the log-prob of the id that was chosen.
      # NOTE: this is accumulated for every row on every executed step,
      # including rows that have already produced eos_id.
      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(log_probs, log_prob_indices)

      next_id = tf.expand_dims(next_id, axis=1)
      # The output grows by one column per step.
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

    def is_not_finished(i, hit_eos, *_):
      """Loop condition: continue until decode_length (or all rows hit EOS)."""
      finished = i >= decode_length
      if not force_decode_length:
        finished |= tf.reduce_all(hit_eos)
      return tf.logical_not(finished)

    # Initial loop state: empty output, nothing finished, id 0 as first input.
    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    hit_eos = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
    # Shape invariants are left unconstrained (None) because decoded_ids
    # changes shape between iterations; the cache invariants come from
    # beam_search.get_state_shape_invariants.
    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop, [
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
            initial_log_prob
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
            tf.TensorShape([None]),
        ])
    scores = log_prob

  return {"outputs": decoded_ids, "scores": scores}
@registry.register_model
class TransformerScorer(Transformer):
"""Transformer model, but only scores in PREDICT mode.
Checkpoints between Transformer and TransformerScorer are interchangeable.
"""
def __init__(self, *args, **kwargs):
  """Builds the scorer and pins its name attributes to "transformer".

  Args:
    *args: positional arguments forwarded unchanged to the parent constructor.
    **kwargs: keyword arguments forwarded unchanged to the parent constructor.
  """
  super(TransformerScorer, self).__init__(*args, **kwargs)
  # Both name attributes are overridden with the same literal; presumably
  # this keeps variable scopes identical to Transformer's so that checkpoints
  # stay interchangeable -- confirm against the base class.
  shared_name = "transformer"
  self._name = shared_name
  self._base_name = shared_name
def infer(self,
features=None,
decode_length=50,
beam_size=1,
top_beams=1,
alpha=0.0,
use_tpu=False):
"""Returns the targets and their log probabilities."""
del decode_length, beam_size, top_beams, alpha, use_tpu
assert features is not None
# Run the model
self.hparams.force_full_predict = True
with tf.variable_scope(self.name):
logits, _ = self.model_fn(features)
assert len(logits.shape) == 5 # [batch, time, 1, 1, vocab]
logits = tf.squeeze(logits, [2, 3])
# Compute the log probabilities
log_probs = common_layers.log_prob_from_logits(logits)
targets = features["targets"]
assert len(targets.shape) == 4 # [batch, time, 1, 1]
targets = tf.squeeze(targets, [2, 3])