# coding=utf-8
# Copyright 2022 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Commonly re-used transformer layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import expert_utils
from tensor2tensor.utils import mlperf_log
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator


# TODO(lukaszkaiser): remove this function when not needed any more.
def layers():
  return common_layers.layers()


def transformer_prepare_encoder(inputs, target_space, hparams, features=None,
                                type_ids=None, num_types=None,
                                reuse_target_embedding=tf.AUTO_REUSE):
"""Prepare one shard of the model for the encoder.
Args:
inputs: a Tensor.
target_space: a Tensor.
hparams: run hyperparameters
features: optionally pass the entire features dictionary as well.
This is needed now for "packed" datasets.
type_ids: optional, an int64 Tensor of shape [batch, length] that allows
for adding type embeddings, similar to positional embeddings.
num_types: optional, an int that decides the number of types in type_ids.
reuse_target_embedding: option to reuse variable name in the case that
symbol modalities are reused between inputs/targets.
Returns:
encoder_input: a Tensor, bottom of encoder stack
encoder_self_attention_bias: a bias tensor for use in encoder self-attention
encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
attention
"""
ishape_static = inputs.shape.as_list()
encoder_input = inputs
if features and "inputs_segmentation" in features:
# Packed dataset. Keep the examples from seeing each other.
inputs_segmentation = features["inputs_segmentation"]
inputs_position = features["inputs_position"]
targets_segmentation = features["targets_segmentation"]
if (hasattr(hparams, "unidirectional_encoder") and
hparams.unidirectional_encoder):
tf.logging.info("Using unidirectional encoder")
encoder_self_attention_bias = (
common_attention.attention_bias_lower_triangle(
common_layers.shape_list(inputs)[1]))
else:
encoder_self_attention_bias = (
common_attention.attention_bias_same_segment(
inputs_segmentation, inputs_segmentation))
encoder_decoder_attention_bias = (
common_attention.attention_bias_same_segment(targets_segmentation,
inputs_segmentation))
else:
encoder_padding = common_attention.embedding_to_padding(encoder_input)
ignore_padding = common_attention.attention_bias_ignore_padding(
encoder_padding)
if (hasattr(hparams, "unidirectional_encoder") and
hparams.unidirectional_encoder):
tf.logging.info("Using unidirectional encoder")
encoder_self_attention_bias = (
common_attention.attention_bias_lower_triangle(
common_layers.shape_list(inputs)[1]))
else:
# Usual case - not a packed dataset.
encoder_self_attention_bias = ignore_padding
encoder_decoder_attention_bias = ignore_padding
inputs_position = None
if hparams.proximity_bias:
encoder_self_attention_bias += common_attention.attention_bias_proximal(
common_layers.shape_list(inputs)[1])
if target_space is not None and hparams.get("use_target_space_embedding",
True):
# Append target_space_id embedding to inputs.
emb_target_space = common_layers.embedding(
target_space,
32,
ishape_static[-1],
name="target_space_embedding",
dtype=hparams.get("activation_dtype", "float32"),
reuse=reuse_target_embedding)
emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
encoder_input += emb_target_space
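  # Add a position signal to the encoder input: fixed sinusoidal timing
  # signals, timing signals derived from position features, or a learned
  # positional embedding, depending on hparams.pos.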
if hparams.pos == "timing":
if inputs_position is not None:
encoder_input = common_attention.add_timing_signal_1d_given_position(
encoder_input, inputs_position)
else:
encoder_input = common_attention.add_timing_signal_1d(encoder_input)
elif hparams.pos == "timing_from_features":
encoder_input = common_attention.add_timing_signals_from_features(
encoder_input, features, hparams.position_features)
elif hparams.pos == "emb":
encoder_input = common_attention.add_positional_embedding(
encoder_input, hparams.max_length, "inputs_positional_embedding",
inputs_position)
# Add type embeddings
if type_ids is not None:
if not num_types:
raise ValueError("Need to set num_types as well.")
encoder_input = common_attention.add_positional_embedding(
encoder_input, num_types, "inputs_type_embedding", type_ids)
encoder_self_attention_bias = common_layers.cast_like(
encoder_self_attention_bias, encoder_input)
encoder_decoder_attention_bias = common_layers.cast_like(
encoder_decoder_attention_bias, encoder_input)
return (encoder_input, encoder_self_attention_bias,
encoder_decoder_attention_bias)
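

# Usage sketch (illustrative only): it assumes a standard transformer hparams
# set such as transformer_base() from tensor2tensor.models.transformer, and
# inputs that are already embedded to [batch, length, hidden_size].
def _example_prepare_encoder(hparams):
  """Sketch: build encoder input and attention biases for a toy batch."""
  # Random nonzero values so embedding_to_padding does not treat positions
  # as padding.
  inputs = tf.random_uniform([2, 5, hparams.hidden_size])
  target_space = tf.constant(1, dtype=tf.int32)  # arbitrary target-space id
  encoder_input, self_attn_bias, encdec_attn_bias = (
      transformer_prepare_encoder(inputs, target_space, hparams))
  # Both bias tensors broadcast against [batch, heads, length, length]
  # attention logits; masked positions hold large negative values.
  return encoder_input, self_attn_bias, encdec_attn_bias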


def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None,
                        attn_bias_for_padding=None):
"""A stack of transformer layers.
Args:
encoder_input: a Tensor
encoder_self_attention_bias: bias Tensor for self-attention
(see common_attention.attention_bias())
hparams: hyperparameters for model
name: a string
nonpadding: optional Tensor with shape [batch_size, encoder_length]
indicating what positions are not padding. This must either be
passed in, which we do for "packed" datasets, or inferred from
encoder_self_attention_bias. The knowledge about padding is used
for pad_remover(efficiency) and to mask out padding in convolutional
layers.
save_weights_to: an optional dictionary to capture attention weights
for visualization; the weights tensor will be appended there under
a string key created from the variable scope (including name).
make_image_summary: Whether to make an attention image summary.
losses: optional list onto which to append extra training losses
attn_bias_for_padding: Padded attention bias in case a unidirectional
encoder is being used where future attention is masked.
Returns:
y: a Tensors
"""
x = encoder_input
attention_dropout_broadcast_dims = (
common_layers.comma_separated_string_to_integer_list(
getattr(hparams, "attention_dropout_broadcast_dims", "")))
mlperf_log.transformer_print(
key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
value=hparams.num_encoder_layers or hparams.num_hidden_layers)
mlperf_log.transformer_print(
key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
value=hparams.attention_dropout)
mlperf_log.transformer_print(
key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
value={
"use_bias": "false",
"num_heads": hparams.num_heads,
"hidden_size": hparams.hidden_size
})
with tf.variable_scope(name):
if nonpadding is not None:
padding = 1.0 - nonpadding
else:
attention_bias = encoder_self_attention_bias
if attn_bias_for_padding is not None:
attention_bias = attn_bias_for_padding
padding = common_attention.attention_bias_to_padding(attention_bias)
nonpadding = 1.0 - padding
pad_remover = None
if hparams.use_pad_remover and not common_layers.is_xla_compiled():
pad_remover = expert_utils.PadRemover(padding)
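    # Each encoder layer is a self-attention sublayer followed by a
    # feed-forward sublayer, each wrapped in layer_preprocess /
    # layer_postprocess (normalization, dropout and the residual add).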
for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
with tf.variable_scope("self_attention"):
if layer < hparams.get("num_area_layers", 0):
max_area_width = hparams.get("max_area_width", 1)
max_area_height = hparams.get("max_area_height", 1)
memory_height = hparams.get("memory_height", 1)
else:
max_area_width = 1
max_area_height = 1
memory_height = 1
y = common_attention.multihead_attention(
common_layers.layer_preprocess(x, hparams),
None,
encoder_self_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
attention_type=hparams.self_attention_type,
max_relative_position=hparams.max_relative_position,
heads_share_relative_embedding=(
hparams.heads_share_relative_embedding),
add_relative_to_values=hparams.add_relative_to_values,
save_weights_to=save_weights_to,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
max_length=hparams.get("max_length"),
vars_3d=hparams.get("attention_variables_3d"),
activation_dtype=hparams.get("activation_dtype", "float32"),
weight_dtype=hparams.get("weight_dtype", "float32"),
hard_attention_k=hparams.get("hard_attention_k", 0),
gumbel_noise_weight=hparams.get("gumbel_noise_weight", 0.0),
max_area_width=max_area_width,
max_area_height=max_area_height,
memory_height=memory_height,
area_key_mode=hparams.get("area_key_mode", "none"),
area_value_mode=hparams.get("area_value_mode", "none"),
training=(hparams.get("mode", tf_estimator.ModeKeys.TRAIN)
== tf_estimator.ModeKeys.TRAIN))
x = common_layers.layer_postprocess(x, y, hparams)
with tf.variable_scope("ffn"):
y = transformer_ffn_layer(
common_layers.layer_preprocess(x, hparams),
hparams,
pad_remover,
conv_padding="SAME",
nonpadding_mask=nonpadding,
losses=losses)
x = common_layers.layer_postprocess(x, y, hparams)
# if normalization is done in layer_preprocess, then it should also be done
# on the output, since the output can grow very large, being the sum of
# a whole stack of unnormalized layer outputs.
mlperf_log.transformer_print(
key=mlperf_log.MODEL_HP_NORM,
value={"hidden_size": hparams.hidden_size})
return common_layers.layer_preprocess(x, hparams)
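

# Usage sketch (illustrative only), continuing the _example_prepare_encoder
# helper above; the hparams set is again assumed to be something like
# transformer_base().
def _example_run_encoder(hparams):
  """Sketch: run the full encoder stack and capture attention weights."""
  encoder_input, self_attn_bias, _ = _example_prepare_encoder(hparams)
  attn_weights = {}  # filled with one entry per attention variable scope
  encoder_output = transformer_encoder(
      encoder_input,
      self_attn_bias,
      hparams,
      save_weights_to=attn_weights)
  # The output keeps the input shape [batch, length, hidden_size]; the final
  # layer_preprocess call normalizes the accumulated residual stream.
  return encoder_output, attn_weights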


def transformer_ffn_layer(x,
                          hparams,
                          pad_remover=None,
                          conv_padding="LEFT",
                          nonpadding_mask=None,
                          losses=None,
                          cache=None,
                          decode_loop_step=None,
                          readout_filter_size=0,
                          layer_collection=None):
"""Feed-forward layer in the transformer.
Args:
x: a Tensor of shape [batch_size, length, hparams.hidden_size]
hparams: hyperparameters for model
pad_remover: an expert_utils.PadRemover object tracking the padding
positions. If provided, when using convolutional settings, the padding
is removed before applying the convolution, and restored afterward. This
can give a significant speedup.
conv_padding: a string - either "LEFT" or "SAME".
nonpadding_mask: an optional Tensor with shape [batch_size, length].
needed for convolutional layers with "SAME" padding.
Contains 1.0 in positions corresponding to nonpadding.
losses: optional list onto which to append extra training losses
cache: dict, containing tensors which are the results of previous
attentions, used for fast decoding.
decode_loop_step: An integer, step number of the decoding loop.
Only used for inference on TPU.
readout_filter_size: if it's greater than 0, then it will be used instead of
filter_size
layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
KFAC optimizer. Default is None.
Returns:
a Tensor of shape [batch_size, length, hparams.hidden_size]
Raises:
ValueError: If losses arg is None, but layer generates extra losses.
"""
ffn_layer = hparams.ffn_layer
relu_dropout_broadcast_dims = (
common_layers.comma_separated_string_to_integer_list(
getattr(hparams, "relu_dropout_broadcast_dims", "")))
if ffn_layer == "conv_hidden_relu":
# Backwards compatibility
ffn_layer = "dense_relu_dense"
if ffn_layer == "dense_relu_dense":
# In simple convolution mode, use `pad_remover` to speed up processing.
mlperf_log.transformer_print(
key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE,
value={
"filter_size": hparams.filter_size,
"use_bias": "True",
"activation": mlperf_log.RELU
})
mlperf_log.transformer_print(
key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE,
value={
"hidden_size": hparams.hidden_size,
"use_bias": "True",
})
mlperf_log.transformer_print(
key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=hparams.relu_dropout)
if pad_remover:
original_shape = common_layers.shape_list(x)
# Collapse `x` across examples, and remove padding positions.
x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0))
x = tf.expand_dims(pad_remover.remove(x), axis=0)
conv_output = common_layers.dense_relu_dense(
x,
hparams.filter_size,
hparams.hidden_size,
dropout=hparams.relu_dropout,
dropout_broadcast_dims=relu_dropout_broadcast_dims,
layer_collection=layer_collection)
if pad_remover:
# Restore `conv_output` to the original shape of `x`, including padding.
conv_output = tf.reshape(
pad_remover.restore(tf.squeeze(conv_output, axis=0)), original_shape)
return conv_output
elif ffn_layer == "conv_relu_conv":
return common_layers.conv_relu_conv(
x,
readout_filter_size or hparams.filter_size,
hparams.hidden_size,
first_kernel_size=hparams.conv_first_kernel,
second_kernel_size=1,
padding=conv_padding,
nonpadding_mask=nonpadding_mask,
dropout=hparams.relu_dropout,
cache=cache,
decode_loop_step=decode_loop_step)
elif ffn_layer == "parameter_attention":
return common_attention.parameter_attention(
x, hparams.parameter_attention_key_channels or hparams.hidden_size,
hparams.parameter_attention_value_channels or hparams.hidden_size,
hparams.hidden_size, readout_filter_size or hparams.filter_size,
hparams.num_heads,
hparams.attention_dropout)
elif ffn_layer == "conv_hidden_relu_with_sepconv":
return common_layers.conv_hidden_relu(
x,
readout_filter_size or hparams.filter_size,
hparams.hidden_size,
kernel_size=(3, 1),
second_kernel_size=(31, 1),
padding="LEFT",
dropout=hparams.relu_dropout)
elif ffn_layer == "sru":
return common_layers.sru(x)
elif ffn_layer == "local_moe_tpu":
overhead = hparams.moe_overhead_eval
if hparams.mode == tf_estimator.ModeKeys.TRAIN:
overhead = hparams.moe_overhead_train
ret, loss = expert_utils.local_moe_tpu(
x,
hparams.filter_size // 2,
hparams.hidden_size,
hparams.moe_num_experts,
overhead=overhead,
        loss_coef=hparams.moe_loss_coef)
    # Record the auxiliary MoE loss and return the expert outputs; otherwise
    # this branch falls through and the computed result is discarded.
    losses.append(loss)
    return ret
elif ffn_layer == "local_moe":
overhead = hparams.moe_overhead_eval
if hparams.mode == tf_estimator.ModeKeys.TRAIN:
overhead = hparams.moe_overhead_train
ret, loss = expert_utils.local_moe(
x,
True,
expert_utils.ffn_expert_fn(hparams.hidden_size, [hparams.filter_size],
hparams.hidden_size),
hparams.moe_num_experts,
k=hparams.moe_k,
hparams=hparams)
losses.append(loss)
return ret
else:
assert ffn_layer == "none"
return x
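

# Usage sketch (illustrative only): shows how hparams.ffn_layer selects the
# feed-forward variant. "dense_relu_dense" is the standard two-layer ReLU MLP
# of the base transformer; the moe variants additionally append an auxiliary
# load-balancing loss to the `losses` argument.
def _example_ffn_variants(x, hparams):
  """Sketch: run two FFN variants on the same [batch, length, hidden] x."""
  hparams.ffn_layer = "dense_relu_dense"
  y_dense = transformer_ffn_layer(x, hparams)
  hparams.ffn_layer = "conv_relu_conv"
  # "SAME"-padded convolutions need a nonpadding mask; all-ones here since
  # this toy input has no padding.
  y_conv = transformer_ffn_layer(x, hparams, conv_padding="SAME",
                                 nonpadding_mask=tf.ones(tf.shape(x)[:2]))
  return y_dense, y_conv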