This repository has been archived by the owner on Jul 7, 2023. It is now read-only.
/
transformer.py
2978 lines (2543 loc) · 107 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer model from "Attention Is All You Need".
The Transformer model consists of an encoder and a decoder. Both are stacks
of self-attention layers followed by feed-forward layers. This model yields
good results on a number of problems, especially in NLP and machine translation.
See "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) for the full
description of the model and the results obtained with its early version.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.data_generators import librispeech
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import modalities
from tensor2tensor.layers import transformer_layers
from tensor2tensor.layers import transformer_memory
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import expert_utils
from tensor2tensor.utils import mlperf_log
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import inplace_ops
from tensorflow.python.util import nest
# pylint: enable=g-direct-tensorflow-import
# Short module-level names for encoder building blocks that live in
# transformer_layers; they are reused both in this module and by code that
# imports it.
transformer_prepare_encoder, transformer_encoder, transformer_ffn_layer = (
    transformer_layers.transformer_prepare_encoder,
    transformer_layers.transformer_encoder,
    transformer_layers.transformer_ffn_layer,
)
def transformer_encode(encoder_function, inputs, target_space, hparams,
                       attention_weights=None, features=None, losses=None,
                       prepare_encoder_fn=None, **kwargs):
  """Runs `encoder_function` over the flattened, prepared `inputs`.

  The inputs are flattened from 4-D to 3-D and handed to `prepare_encoder_fn`
  (default: transformer_prepare_encoder), which yields the encoder input and
  the self-attention / encoder-decoder attention biases. The encoder input
  then goes through dropout (rate hparams.layer_prepostprocess_dropout)
  before being fed to the encoder.

  Args:
    encoder_function: the encoder function
    inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim] which
      will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    attention_weights: weight to store attention to.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: optional list onto which to append extra training losses
    prepare_encoder_fn: optional, alternative to transformer_prepare_encoder.
    **kwargs: additional arguments to pass to encoder_function

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
  """
  flat_inputs = common_layers.flatten4d3d(inputs)
  prepare_fn = prepare_encoder_fn or transformer_prepare_encoder
  (encoder_input, self_attention_bias,
   encoder_decoder_attention_bias) = prepare_fn(
       flat_inputs, target_space, hparams, features=features)

  dropout_rate = hparams.layer_prepostprocess_dropout
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
      value=dropout_rate,
      hparams=hparams)
  # tf.compat.v1's dropout takes the *keep* probability as its second arg.
  encoder_input = tf.nn.dropout(encoder_input, 1.0 - dropout_rate)

  # A unidirectional encoder additionally receives the encoder-decoder bias as
  # its padding bias; otherwise the encoder just uses self_attention_bias.
  padding_bias = (encoder_decoder_attention_bias
                  if hparams.unidirectional_encoder else None)

  encoder_output = encoder_function(
      encoder_input,
      self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=attention_weights,
      make_image_summary=not common_layers.is_xla_compiled(),
      losses=losses,
      attn_bias_for_padding=padding_bias,
      **kwargs)
  return encoder_output, encoder_decoder_attention_bias
def transformer_decode(decoder_function,
                       decoder_input,
                       encoder_output,
                       encoder_decoder_attention_bias,
                       decoder_self_attention_bias,
                       hparams,
                       attention_weights=None,
                       cache=None,
                       decode_loop_step=None,
                       nonpadding=None,
                       losses=None,
                       **kwargs):
  """Runs `decoder_function` on `decoder_input` and the encoder output.

  Dropout (rate hparams.layer_prepostprocess_dropout) is applied to the
  decoder input first. The decoder's output is returned as is during
  XLA-compiled training, and with an extra axis inserted at position 2
  otherwise.

  Args:
    decoder_function: the decoder function
    decoder_input: inputs to bottom of the model. [batch_size, decoder_length,
      hidden_dim]
    encoder_output: Encoder representation. [batch_size, input_length,
      hidden_dim]
    encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder
      attention. [batch_size, input_length]
    decoder_self_attention_bias: Bias and mask weights for decoder
      self-attention. [batch_size, decoder_length]
    hparams: hyperparameters for model.
    attention_weights: weight to store attention to.
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    nonpadding: optional Tensor with shape [batch_size, decoder_length]
    losses: optional list onto which to append extra training losses
    **kwargs: additional arguments to pass to decoder_function

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
    when XLA-compiled in TRAIN mode, otherwise with an extra axis at
    position 2 (t2t expects 4d tensors).
  """
  dropout_rate = hparams.layer_prepostprocess_dropout
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
      value=dropout_rate,
      hparams=hparams)
  # tf.compat.v1's dropout takes the *keep* probability as its second arg.
  dropped_input = tf.nn.dropout(decoder_input, 1.0 - dropout_rate)

  # Note: decoder_function takes the self-attention bias before the
  # encoder-decoder bias.
  decoder_output = decoder_function(
      dropped_input,
      encoder_output,
      decoder_self_attention_bias,
      encoder_decoder_attention_bias,
      hparams,
      cache=cache,
      decode_loop_step=decode_loop_step,
      nonpadding=nonpadding,
      save_weights_to=attention_weights,
      losses=losses,
      **kwargs)

  # TPU does not react kindly to extra dimensions.
  # TODO(noam): remove this once TPU is more forgiving of extra dims.
  keep_3d = (common_layers.is_xla_compiled() and
             hparams.mode == tf_estimator.ModeKeys.TRAIN)
  if keep_3d:
    return decoder_output
  # Expand since t2t expects 4d tensors.
  return tf.expand_dims(decoder_output, axis=2)
@registry.register_model
class Transformer(t2t_model.T2TModel):
"""Attention net. See file docstring."""
def __init__(self, *args, **kwargs):
  """Builds the model and installs the default Transformer components.

  The component attributes set here are read through `self` by encode(),
  decode(), body() and the fast decoders, so a subclass may replace them.
  """
  super(Transformer, self).__init__(*args, **kwargs)
  # Passed as the attention-weight store to the encoder/decoder; used for
  # visualizing attention heads.
  self.attention_weights = {}
  # Override to enable recurrent memory (body() checks it against None).
  self.recurrent_memory_by_layer = None
  # Pluggable building blocks.
  self._encoder_function = transformer_encoder
  self._decoder_function = transformer_decoder
  self._init_cache_fn = _init_transformer_cache
  self._prepare_encoder_fn = transformer_prepare_encoder
  self._prepare_decoder_fn = transformer_prepare_decoder
def encode(self, inputs, target_space, hparams, features=None, losses=None):
  """Encodes `inputs` with this model's encoder; see transformer_encode.

  Uses `self._encoder_function` and `self._prepare_encoder_fn`, and passes
  `self.attention_weights` as the attention-weight store.
  """
  return transformer_encode(
      encoder_function=self._encoder_function,
      inputs=inputs,
      target_space=target_space,
      hparams=hparams,
      attention_weights=self.attention_weights,
      features=features,
      losses=losses,
      prepare_encoder_fn=self._prepare_encoder_fn)
def decode(self,
           decoder_input,
           encoder_output,
           encoder_decoder_attention_bias,
           decoder_self_attention_bias,
           hparams,
           cache=None,
           decode_loop_step=None,
           nonpadding=None,
           losses=None,
           **kwargs):
  """Decodes with this model's decoder; see transformer_decode.

  Uses `self._decoder_function` and passes `self.attention_weights` as the
  attention-weight store; every other argument is forwarded unchanged.
  """
  return transformer_decode(
      decoder_function=self._decoder_function,
      decoder_input=decoder_input,
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      decoder_self_attention_bias=decoder_self_attention_bias,
      hparams=hparams,
      attention_weights=self.attention_weights,
      cache=cache,
      decode_loop_step=decode_loop_step,
      nonpadding=nonpadding,
      losses=losses,
      **kwargs)
def body(self, features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs. [batch_size, input_length, 1,
          hidden_dim].
        "targets": Target decoder outputs. [batch_size, decoder_length, 1,
          hidden_dim]
        "target_space_id": A scalar int from data_generators.problem.SpaceID.
      It must also contain "chunk_number" when recurrent memory is enabled,
      and may contain "expected_attentions" to enable an attention loss.

  Returns:
    Either the decoder output alone, or a tuple (decoder output, dict of
    extra losses).
    - Without "expected_attentions", the decoder output is reshaped to the
      shape of features["targets"]. A dict {"extra_loss": ...} (the sum of
      the losses appended by the encoder/decoder) is returned alongside it
      only if any such losses were collected.
    - With "expected_attentions", the decoder output is returned unreshaped
      together with {"attention_loss": ...}, plus "extra_loss" when
      encoder/decoder losses were collected.
  """
  hparams = self._hparams

  # Auxiliary training losses appended by the encoder/decoder layers.
  losses = []

  if self.has_input:
    inputs = self._prepare_inputs_for_body(features)
    target_space = features["target_space_id"]
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs, target_space, hparams, features=features, losses=losses)
  else:
    encoder_output, encoder_decoder_attention_bias = (None, None)

  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)
  decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn(
      targets, hparams, features=features)

  # Not all subclasses of Transformer support keyword arguments related to
  # recurrent memory, so only pass these arguments if memory is enabled.
  decode_kwargs = {}
  if self.recurrent_memory_by_layer is not None:
    # TODO(kitaev): The chunk_number feature currently has the same shape as
    # "targets", but this is only for the purposes of sharing sharding code.
    # In fact every token within an example must have the same chunk number.
    chunk_number_each_token = tf.squeeze(features["chunk_number"], (-1, -2))
    chunk_number_each_example = chunk_number_each_token[:, 0]
    # Uncomment the code below to verify that tokens within a batch share the
    # same chunk number:
    # with tf.control_dependencies([
    #     tf.assert_equal(chunk_number_each_token,
    #                     chunk_number_each_example[:, None])
    # ]):
    #   chunk_number_each_example = tf.identity(chunk_number_each_example)
    decode_kwargs = dict(
        recurrent_memory_by_layer=self.recurrent_memory_by_layer,
        chunk_number=chunk_number_each_example,
    )

  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "targets"),
      losses=losses,
      **decode_kwargs
  )

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    extra_losses = {"attention_loss": attention_loss}
    # Also report the encoder/decoder auxiliary losses here, so they are not
    # dropped when an attention loss is requested.
    if losses:
      extra_losses["extra_loss"] = tf.add_n(losses)
    return decoder_output, extra_losses

  ret = tf.reshape(decoder_output, targets_shape)
  if losses:
    return ret, {"extra_loss": tf.add_n(losses)}
  return ret
def _prepare_inputs_for_body(self, features):
  """Returns the tensor that body() passes to encode().

  This implementation hands back features["inputs"] unchanged.

  Args:
    features: Map of string to model features; must contain "inputs",
      Transformer inputs of shape [batch_size, input_length, 1, hidden_dim].

  Returns:
    Inputs for the model. [batch_size, input_length, 1, hidden_dim]
  """
  model_inputs = features["inputs"]
  return model_inputs
def _greedy_infer(self, features, decode_length, use_tpu=False):
  """Fast version of greedy decoding.

  Falls back to the parent class's greedy decoding when the target modality
  is real-valued or the self-attention type is not "dot_product". Otherwise
  it runs the cached fast decoder (TPU or non-TPU variant) inside this
  model's variable scope.

  Args:
    features: an map of string to `Tensor`
    decode_length: an integer. How many additional timesteps to decode.
    use_tpu: A bool. Whether to build the inference graph for TPU. It is
      also forwarded to the parent implementation on the fallback path.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  # For real-valued modalities use the slow decode path for now.
  if (self._target_modality_is_real or
      self._hparams.self_attention_type != "dot_product"):
    # Forward use_tpu so a TPU decode does not silently get the non-TPU slow
    # graph. _beam_decode does the same for _beam_decode_slow.
    # NOTE(review): assumes the parent's _greedy_infer accepts use_tpu, as
    # T2TModel's does.
    return super(Transformer, self)._greedy_infer(
        features, decode_length, use_tpu=use_tpu)
  with tf.variable_scope(self.name):
    if use_tpu:
      return self._fast_decode_tpu(features, decode_length)
    return self._fast_decode(features, decode_length)
def _beam_decode(self,
                 features,
                 decode_length,
                 beam_size,
                 top_beams,
                 alpha,
                 use_tpu=False):
  """Beam search decoding.

  Uses the cached fast decoders only for "dot_product" and
  "dot_product_relative" self-attention; any other attention type goes
  through the slow beam search.

  Args:
    features: an map of string to `Tensor`
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.
    use_tpu: A bool, whether to do beam decode on TPU.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }
  """
  cache_safe_attention_types = ("dot_product", "dot_product_relative")
  if self._hparams.self_attention_type not in cache_safe_attention_types:
    # Caching is not guaranteed to work with attention types other than
    # dot_product and dot_product_relative.
    return self._beam_decode_slow(features, decode_length, beam_size,
                                  top_beams, alpha, use_tpu)
  with tf.variable_scope(self.name):
    fast_decode_fn = self._fast_decode_tpu if use_tpu else self._fast_decode
    return fast_decode_fn(features, decode_length, beam_size, top_beams,
                          alpha)
def _prepare_inputs_for_decode(self, features):
  """Reshapes features["inputs"] and runs the inputs modality bottom on it.

  Args:
    features: A map of string to model features.

  Returns:
    Inputs after fixing shape and applying modality (the data-parallel
    result of the bottom function).
  """
  data_parallel = self._data_parallelism
  hparams = self._hparams

  # TODO(llion): Clean up this reshaping logic.
  # Insert a unit axis at position 1, add a trailing axis if the result is
  # not yet 5-D, then fold axes 0 and 1 back into a single leading axis.
  decode_inputs = tf.expand_dims(features["inputs"], axis=1)
  if len(decode_inputs.shape) < 5:
    decode_inputs = tf.expand_dims(decode_inputs, axis=4)
  dims = common_layers.shape_list(decode_inputs)
  decode_inputs = tf.reshape(
      decode_inputs, [dims[0] * dims[1], dims[2], dims[3], dims[4]])

  # _shard_features called to ensure that the variable names match
  decode_inputs = self._shard_features({"inputs": decode_inputs})["inputs"]

  problem_hparams = self._problem_hparams
  input_modality = problem_hparams.modality["inputs"]
  input_vocab_size = problem_hparams.vocab_size["inputs"]
  if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
    # Round the vocab size up to the next multiple of vocab_divisor.
    input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor

  name_fn = hparams.name.get("inputs", modalities.get_name(input_modality))
  with tf.variable_scope(name_fn(hparams, input_vocab_size)):
    bottom_fn = hparams.bottom.get("inputs",
                                   modalities.get_bottom(input_modality))
    return data_parallel(bottom_fn, decode_inputs, hparams, input_vocab_size)
def _fast_decode_tpu(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding on TPU, uses beam search
  iff beam_size > 1, otherwise beam search related arguments are ignored.

  With inputs, the encoder is run once up front. The decode budget is the
  input length plus features["decode_length"] (falling back to the
  `decode_length` argument), or exactly 1 for class-label / regression
  targets. Without inputs, features["inputs"] (or features["targets"]) is
  used as a forced prefix of the output. The budget is then the prefix
  length plus the requested extra steps, and the prefix is stripped from the
  returned ids.

  Args:
    features: A map of string to model features.
    decode_length: An integer, how many additional timesteps to decode.
    beam_size: An integer, number of beams.
    top_beams: An integer, how many of the beams to return.
    alpha: A float that controls the length penalty. Larger the alpha,
      stronger the preference for longer translations.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }.

  Raises:
    NotImplementedError: If there are multiple data shards, or if the
      features come from a packed dataset ("targets_segmentation" present).
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  if "targets_segmentation" in features:
    raise NotImplementedError(
        "Decoding not supported on packed datasets "
        " If you want to decode from a dataset, use the non-packed version"
        " of the dataset when decoding.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.modality["targets"]
  target_vocab_size = self._problem_hparams.vocab_size["targets"]
  # Round the vocab size up to a multiple of vocab_divisor when it is set.
  if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
    target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor

  if self.has_input:
    inputs_shape = common_layers.shape_list(features["inputs"])
    # Class-label and regression targets are a single step.
    if (target_modality == modalities.ModalityType.CLASS_LABEL or
        self._problem_hparams.get("regression_targets")):
      decode_length = 1
    else:
      decode_length = (
          inputs_shape[1] + features.get("decode_length", decode_length))
    batch_size = inputs_shape[0]
    inputs = self._prepare_inputs_for_decode(features)
    # Run the encoder once; its output is reused at every decoding step.
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode,
          inputs,
          features["target_space_id"],
          hparams,
          features=features)
    # dp returns per-shard lists; there is exactly one shard (checked above).
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = None
  else:
    # The problem has no inputs.
    encoder_output = None
    encoder_decoder_attention_bias = None

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs")
    if partial_targets is None:
      partial_targets = features["targets"]
    assert partial_targets is not None
    partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
    partial_targets = tf.to_int64(partial_targets)
    partial_targets_shape = common_layers.shape_list(partial_targets)
    partial_targets_length = partial_targets_shape[1]
    decode_length = (
        partial_targets_length + features.get("decode_length", decode_length))
    batch_size = partial_targets_shape[0]

  # Positional signal for decode_length + 1 positions; preprocess_targets
  # slices out position i at each step.
  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "timing_from_features":
    positional_encoding = common_attention.add_timing_signals_from_features(
        tf.zeros([1, decode_length + 1, hparams.hidden_size]), features,
        hparams.position_features)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length + 1, hparams.hidden_size]),
        hparams.max_length, "body/targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: A tensor, inputs ids to the decoder. [batch_size, 1].
      i: An integer, Step number of the decoding loop.

    Returns:
      A tensor, processed targets [batch_size, 1, hidden_dim].
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    modality_name = hparams.name.get(
        "targets",
        modalities.get_name(target_modality))(hparams, target_vocab_size)
    with tf.variable_scope(modality_name):
      bottom = hparams.bottom.get(
          "targets", modalities.get_targets_bottom(target_modality))
      targets = dp(bottom, targets, hparams, target_vocab_size)[0]
    targets = common_layers.flatten4d3d(targets)

    # GO embeddings are all zero, this is because transformer_prepare_decoder
    # Shifts the targets along by one for the input which pads with zeros.
    # If the modality already maps GO to the zero embeddings this is not
    # needed.
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      # tf.slice with explicit static sizes selects position i.
      positional_encoding_shape = positional_encoding.shape.as_list()
      targets += tf.slice(
          positional_encoding, [0, i, 0],
          [positional_encoding_shape[0], 1, positional_encoding_shape[2]])
    return targets

  # Self-attention bias for the full decode length (lower-triangle helper,
  # plus an optional proximity bias); sliced per step below.
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_tpu_fn(ids, i, cache):
    """Go from ids to logits for next symbol on TPU.

    Args:
      ids: A tensor, symbol IDs.
      i: An integer, step number of the decoding loop. Only used for inference
        on TPU.
      cache: A dict, containing tensors which are the results of previous
        attentions, used for fast decoding.

    Returns:
      ret: A tensor, computed logits.
      cache: A dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    """
    # Only the most recent id is embedded; earlier steps are presumably
    # represented through `cache`.
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    # Row i of the 4-D self-attention bias, taken with a fixed-size tf.slice
    # over all key positions.
    bias_shape = decoder_self_attention_bias.shape.as_list()
    bias = tf.slice(decoder_self_attention_bias, [0, 0, i, 0],
                    [bias_shape[0], bias_shape[1], 1, bias_shape[3]])

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          i,
          nonpadding=features_to_nonpadding(features, "targets"))
    modality_name = hparams.name.get(
        "targets",
        modalities.get_name(target_modality))(hparams, target_vocab_size)
    with tf.variable_scope(modality_name):
      top = hparams.top.get("targets",
                            modalities.get_top(target_modality))
      logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        # NOTE(review): tf.slice here yields a rank-2 [batch, 1] tensor, but
        # `[beam_size]` gives tf.tile only one multiple; tf.tile expects one
        # per input dimension, so this looks like it would fail whenever
        # partial_targets is set. The non-TPU _fast_decode tiles the 1-D
        # partial_targets[:, i] instead. TODO confirm and align.
        return tf.one_hot(
            tf.tile(
                tf.slice(partial_targets, [0, i],
                         [partial_targets.shape.as_list()[0], 1]),
                [beam_size]), vocab_size, 0.0, -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache

  # `or` means an end id of 0 (not only None) also falls back to
  # beam_search.EOS_ID.
  eos_id = self.get_decode_end_id() or beam_search.EOS_ID
  # Per-example feature values take precedence over hparams defaults.
  temperature = features.get("sampling_temp",
                             getattr(hparams, "sampling_temp", 0.0))
  top_k = features.get("sampling_keep_top_k",
                       getattr(hparams, "sampling_keep_top_k", -1))

  ret = fast_decode_tpu(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_tpu_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_vocab_size,
      init_cache_fn=self._init_cache_fn,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length,
      eos_id=eos_id,
      sampling_temperature=temperature,
      top_k=top_k)
  # Strip the forced prefix from the returned ids (time is the last axis).
  if partial_targets is not None:
    if beam_size <= 1 or top_beams <= 1:
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    else:
      ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
  return ret
def get_decode_start_id(self):
  """Id of the symbol fed to the decoder at the first step, or None.

  Returning None (as done here) means the first decoder input is a vector of
  zeros rather than a looked-up embedding. A model that wants a dedicated
  start symbol can override this; the returned id is used to index the
  embedding matrix to obtain the first decoder input.
  """
  return None
def get_decode_end_id(self):
  """Id of the output symbol that marks the end of generation, or None.

  Subclasses may override this so that decoding checks for a different
  terminating symbol. With the None returned here, _fast_decode_tpu falls
  back to beam_search.EOS_ID.
  """
  return None
def _fast_decode(self,
features,
decode_length,
beam_size=1,
top_beams=1,
alpha=1.0,
preprocess_targets_method=None):
"""Fast decoding.
Implements both greedy and beam search decoding, uses beam search iff
beam_size > 1, otherwise beam search related arguments are ignored.
Args:
features: a map of string to model features.
decode_length: an integer. How many additional timesteps to decode.
beam_size: number of beams.
top_beams: an integer. How many of the beams to return.
alpha: Float that controls the length penalty. larger the alpha, stronger
the preference for longer translations.
preprocess_targets_method: method used to preprocess targets. If None,
uses method "preprocess_targets" defined inside this method.
Returns:
A dict of decoding results {
"outputs": integer `Tensor` of decoded ids of shape
[batch_size, <= decode_length] if beam_size == 1 or
[batch_size, top_beams, <= decode_length]
"scores": decoding log probs from the beam search,
None if using greedy decoding (beam_size=1)
}
Raises:
NotImplementedError: If there are multiple data shards.
"""
if self._num_datashards != 1:
raise NotImplementedError("Fast decoding only supports a single shard.")
dp = self._data_parallelism
hparams = self._hparams
target_modality = self._problem_hparams.modality["targets"]
target_vocab_size = self._problem_hparams.vocab_size["targets"]
if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
if "targets_segmentation" in features:
raise NotImplementedError(
"Decoding not supported on packed datasets "
" If you want to decode from a dataset, use the non-packed version"
" of the dataset when decoding.")
if self.has_input:
inputs_shape = common_layers.shape_list(features["inputs"])
if (target_modality == modalities.ModalityType.CLASS_LABEL or
self._problem_hparams.get("regression_targets")):
decode_length = 1
else:
decode_length = (
inputs_shape[1] + features.get("decode_length", decode_length))
batch_size = inputs_shape[0]
inputs = self._prepare_inputs_for_decode(features)
with tf.variable_scope("body"):
encoder_output, encoder_decoder_attention_bias = dp(
self.encode,
inputs,
features["target_space_id"],
hparams,
features=features)
encoder_output = encoder_output[0]
encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
partial_targets = features.get("partial_targets")
else:
# The problem has no inputs.
encoder_output = None
encoder_decoder_attention_bias = None
# Prepare partial targets.
# In either features["inputs"] or features["targets"].
# We force the outputs to begin with these sequences.
partial_targets = features.get("inputs")
if partial_targets is None:
partial_targets = features["targets"]
assert partial_targets is not None
if partial_targets is not None:
partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
partial_targets = tf.to_int64(partial_targets)
partial_targets_shape = common_layers.shape_list(partial_targets)
partial_targets_length = partial_targets_shape[1]
decode_length = (
partial_targets_length + features.get("decode_length", decode_length))
batch_size = partial_targets_shape[0]
if hparams.pos == "timing":
positional_encoding = common_attention.get_timing_signal_1d(
decode_length + 1, hparams.hidden_size)
elif hparams.pos == "timing_from_features":
positional_encoding = common_attention.add_timing_signals_from_features(
tf.zeros([1, decode_length, hparams.hidden_size]), features,
hparams.position_features)
elif hparams.pos == "emb":
positional_encoding = common_attention.add_positional_embedding(
tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length,
"body/targets_positional_embedding", None)
else:
positional_encoding = None
def preprocess_targets(targets, i):
"""Performs preprocessing steps on the targets to prepare for the decoder.
This includes:
- Embedding the ids.
- Flattening to 3D tensor.
- Optionally adding timing signals.
Args:
targets: inputs ids to the decoder. [batch_size, 1]
i: scalar, Step number of the decoding loop.
Returns:
Processed targets [batch_size, 1, hidden_dim]
"""
# _shard_features called to ensure that the variable names match
targets = self._shard_features({"targets": targets})["targets"]
modality_name = hparams.name.get(
"targets",
modalities.get_name(target_modality))(hparams, target_vocab_size)
with tf.variable_scope(modality_name):
bottom = hparams.bottom.get(
"targets", modalities.get_targets_bottom(target_modality))
targets = dp(bottom, targets, hparams, target_vocab_size)[0]
targets = common_layers.flatten4d3d(targets)
# GO embeddings are all zero, this is because transformer_prepare_decoder
# Shifts the targets along by one for the input which pads with zeros.
# If the modality already maps GO to the zero embeddings this is not
# needed.
if not self.get_decode_start_id():
targets = tf.cond(
tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)
if positional_encoding is not None:
targets += positional_encoding[:, i:i + 1]
return targets
decoder_self_attention_bias = (
common_attention.attention_bias_lower_triangle(decode_length))
if hparams.proximity_bias:
decoder_self_attention_bias += common_attention.attention_bias_proximal(
decode_length)
# Create tensors for encoder-decoder attention history
att_cache = {"attention_history": {}}
num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
if encoder_output is not None:
att_batch_size, enc_seq_length = common_layers.shape_list(
encoder_output)[0:2]
for layer in range(num_layers):
att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
[att_batch_size, hparams.num_heads, 0, enc_seq_length])
def update_decoder_attention_history(cache):
"""Save attention weights in cache, e.g., for vizualization."""
for k in [x for x in self.attention_weights
if "decoder" in x and "self" not in x and "logits" not in x]:
idx = k.find("layer_")
if idx < 0:
continue
# Get layer number from the string name.
layer_nbr = k[idx + 6:]
idx = 0
while idx + 1 < len(layer_nbr) and layer_nbr[:idx + 1].isdigit():
idx += 1
layer_nbr = "layer_%d" % int(layer_nbr[:idx])
if layer_nbr in cache["attention_history"]:
cache["attention_history"][layer_nbr] = tf.concat(
[cache["attention_history"][layer_nbr],
self.attention_weights[k]],
axis=2)
if not preprocess_targets_method:
preprocess_targets_method = preprocess_targets
def symbols_to_logits_fn(ids, i, cache):
"""Go from ids to logits for next symbol."""
ids = ids[:, -1:]
targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
targets = preprocess_targets_method(targets, i)
bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
with tf.variable_scope("body"):
body_outputs = dp(
self.decode,
targets,
cache.get("encoder_output"),
cache.get("encoder_decoder_attention_bias"),
bias,
hparams,
cache,
nonpadding=features_to_nonpadding(features, "targets"))
update_decoder_attention_history(cache)
modality_name = hparams.name.get(
"targets",
modalities.get_name(target_modality))(hparams, target_vocab_size)
with tf.variable_scope(modality_name):
top = hparams.top.get("targets", modalities.get_top(target_modality))
logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]
ret = tf.squeeze(logits, axis=[1, 2, 3])
if partial_targets is not None:
# If the position is within the given partial targets, we alter the
# logits to always return those values.
# A faster approach would be to process the partial targets in one
# iteration in order to fill the corresponding parts of the cache.
# This would require broader changes, though.
vocab_size = tf.shape(ret)[1]
def forced_logits():
return tf.one_hot(
tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
-1e9)
ret = tf.cond(
tf.less(i, partial_targets_length), forced_logits, lambda: ret)
return ret, cache
sos_id = self.get_decode_start_id() or 0
eos_id = self.get_decode_end_id() or beam_search.EOS_ID
temperature = features.get("sampling_temp",
getattr(hparams, "sampling_temp", 0.0))
top_k = features.get("sampling_keep_top_k",
getattr(hparams, "sampling_keep_top_k", -1))
ret = fast_decode(
encoder_output=encoder_output,
encoder_decoder_attention_bias=encoder_decoder_attention_bias,
symbols_to_logits_fn=symbols_to_logits_fn,
hparams=hparams,
decode_length=decode_length,
vocab_size=target_vocab_size,
init_cache_fn=self._init_cache_fn,
beam_size=beam_size,
top_beams=top_beams,
alpha=alpha,
batch_size=batch_size,
force_decode_length=self._decode_hparams.force_decode_length,
sos_id=sos_id,
eos_id=eos_id,
sampling_temperature=temperature,
top_k=top_k,
cache=att_cache)
if partial_targets is not None:
if beam_size <= 1 or top_beams <= 1:
ret["outputs"] = ret["outputs"][:, partial_targets_length:]
else:
ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
return ret
def _init_transformer_cache(cache, hparams, batch_size, attention_init_length,
encoder_output, encoder_decoder_attention_bias,
scope_prefix):
"""Create the initial cache for Transformer fast decoding."""
key_channels = hparams.attention_key_channels or hparams.hidden_size
value_channels = hparams.attention_value_channels or hparams.hidden_size
num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
vars_3d_num_heads = (
hparams.num_heads if hparams.get("attention_variables_3d") else 0)
if cache is None:
cache = {}
cache.update({
"layer_%d" % layer: { # pylint: disable=g-complex-comprehension
"k":
common_attention.split_heads(
tf.zeros([batch_size,
attention_init_length,
key_channels]), hparams.num_heads),
"v":
common_attention.split_heads(
tf.zeros([batch_size,
attention_init_length,
value_channels]), hparams.num_heads),
} for layer in range(num_layers)
})
# If `ffn_layer` is in `["dense_relu_dense" or "conv_hidden_relu"]`, then the
# cache key "f" won't be used, which means that the` shape of cache["f"]`
# won't be changed to
# `[beamsize*batch_size, decode_length, hparams.hidden_size]` and may cause
# error when applying `nest.map reshape function` on it.
if hparams.ffn_layer not in ["dense_relu_dense", "conv_hidden_relu"]:
for layer in range(num_layers):
cache["layer_%d" % layer]["f"] = tf.zeros(
[batch_size, 0, hparams.hidden_size])
if encoder_output is not None:
for layer in range(num_layers):
layer_name = "layer_%d" % layer
with tf.variable_scope(
"%sdecoder/%s/encdec_attention/multihead_attention" %
(scope_prefix, layer_name)):
k_encdec = common_attention.compute_attention_component(
encoder_output,
key_channels,
name="k",
vars_3d_num_heads=vars_3d_num_heads)
k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
v_encdec = common_attention.compute_attention_component(
encoder_output,
value_channels,
name="v",
vars_3d_num_heads=vars_3d_num_heads)
v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
cache[layer_name]["k_encdec"] = k_encdec
cache[layer_name]["v_encdec"] = v_encdec
cache["encoder_output"] = encoder_output
cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
return cache
def fast_decode_tpu(encoder_output,
encoder_decoder_attention_bias,
symbols_to_logits_fn,
hparams,
decode_length,
vocab_size,
init_cache_fn=_init_transformer_cache,
beam_size=1,