This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f5c9b17

nshazeer authored and Ryan Sepassi committed
Added options for configuring different types of processing on layer input and layer output (normalization, dropout, residuals). These settings are configured by common_hparams, and should work across many models. Normalization on layer input instead of after the residual seems to help in learning deep networks. This change breaks current model checkpoints.
PiperOrigin-RevId: 164630450
1 parent 9c54d86 commit f5c9b17

9 files changed: +307 additions, −220 deletions


tensor2tensor/layers/common_hparams.py

Lines changed: 16 additions & 1 deletion
@@ -69,8 +69,23 @@ def basic_params1():
       sampling_method="argmax",  # "argmax" or "random"
       problem_choice="adaptive",  # "uniform", "adaptive", "distributed"
       multiply_embedding_mode="sqrt_depth",
+      # Sequences of operations to perform on layer input and layer output.
+      # Used by common_layers.layer_preprocess, common_layers.layer_postprocess
+      # Each character represents an operation:
+      #   d: apply dropout
+      #   n: apply normalization (see norm_type and norm_epsilon)
+      #   a: add layer input (residual connection - only during postprocess)
+      # TODO(noam): The current settings ("", "dan") are the published version
+      # of the transformer.  ("n", "da") seems better for harder-to-learn
+      # models, so it should probably be the default.
+      layer_preprocess_sequence="",
+      layer_postprocess_sequence="dan",
+      # dropout rate to use during layer_preprocess and layer_postprocess
+      layer_prepostprocess_dropout=0.1,
+      # What type of normalization to use
       norm_type="none",  # "batch", "layer", "noam", "none".
-      layer_norm_epsilon=1e-6,
+      # epsilon parameter to normalization function
+      norm_epsilon=1e-6,
       symbol_modality_num_shards=16,
       # setting the max length in a minibatch. 0 means default behavior,
       # max_length = hparams.batch_size * length_multiplier
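
For orientation, here is how a model configuration would opt into the pre-normalization variant recommended in the TODO above. basic_params1() is the function touched by this diff; the override values below are illustrative, not defaults shipped by the commit.

from tensor2tensor.layers import common_hparams

hparams = common_hparams.basic_params1()
# Defaults added by this commit: no preprocessing, and
# dropout -> residual add -> normalize on the output side ("dan").
assert hparams.layer_preprocess_sequence == ""
assert hparams.layer_postprocess_sequence == "dan"

# Pre-normalization variant the TODO suggests for harder-to-learn models
# (illustrative overrides, not values set by this commit):
hparams.layer_preprocess_sequence = "n"    # normalize the layer input
hparams.layer_postprocess_sequence = "da"  # dropout, then add the residual
hparams.norm_type = "layer"                # default in this commit is "none"
hparams.norm_epsilon = 1e-6
hparams.layer_prepostprocess_dropout = 0.1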

tensor2tensor/layers/common_layers.py

Lines changed: 109 additions & 37 deletions
@@ -462,64 +462,136 @@ def layer_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
     return result
 
 
-def noam_norm(x, name=None):
+def noam_norm(x, epsilon=1.0, name=None):
   """One version of layer normalization."""
   with tf.name_scope(name, default_name="noam_norm", values=[x]):
     shape = x.get_shape()
     ndims = len(shape)
-    return (tf.nn.l2_normalize(x, ndims - 1, epsilon=1.0) *
+    return (tf.nn.l2_normalize(x, ndims - 1, epsilon=epsilon) *
             tf.sqrt(tf.to_float(shape[-1])))
 
 
-def get_norm(norm_type):
-  """Get the normalizer function."""
+def apply_norm(x, norm_type, depth, epsilon):
+  """Apply Normalization."""
   if norm_type == "layer":
-    return lambda x, name, filters=None, epsilon=1e-6: layer_norm(  # pylint: disable=g-long-lambda
-        x, filters=filters, epsilon=epsilon, name=name)
+    return layer_norm(x, filters=depth, epsilon=epsilon)
   if norm_type == "batch":
-    return tf.layers.batch_normalization
+    return tf.layers.batch_normalization(x, epsilon=epsilon)
   if norm_type == "noam":
-    return noam_norm
+    return noam_norm(x, epsilon)
   if norm_type == "none":
-    return lambda x, name: x
+    return x
   raise ValueError("Parameter normalizer_fn must be one of: 'layer', 'batch',"
                    "'noam', 'none'.")
 
 
-def residual_fn(x,
-                y,
-                norm_type,
-                residual_dropout,
-                filters=None,
-                epsilon=1e-16,
-                name=None,
-                reuse=None):
-  """Returns a function for combining layer input and layer output.
+def layer_prepostprocess(previous_value,
+                         x,
+                         sequence,
+                         dropout_rate,
+                         norm_type,
+                         depth,
+                         epsilon,
+                         name):
+  """Apply a sequence of functions to the input or output of a layer.
+
+  The sequence is specified as a string which may contain the following
+  characters:
+    a: add previous_value
+    n: apply normalization
+    d: apply dropout
 
-  The returned function on x (layer input) and y (layer output) computes:
-    norm_function(x + dropout(y))
+  For example, if sequence=="dna", then the output is
+    previous_value + normalize(dropout(x))
 
   Args:
-    x: tensor, input layer
-    y: tensor, output layer
-    norm_type: string, type of normalizer function
-    residual_dropout: integer, dropout value for residual connection
-    filters: integer, dimension for layer norm, optional
-    epsilon: integer, value of layer norm epsilon
-    name: string, name
-    reuse: bool, whether to reuse
+    previous_value: A Tensor, to be added as a residual connection ('a')
+    x: A Tensor to be transformed.
+    sequence: a string.
+    dropout_rate: a float
+    norm_type: a string (see apply_norm())
+    depth: an integer (size of last dimension of x).
+    epsilon: a float (parameter for normalization)
+    name: a string
 
   Returns:
-    residual layer output with applied norm_fn.
+    a Tensor
   """
-  with tf.variable_scope(
-      name, default_name="residual", values=[x, y], reuse=reuse):
-    norm_fn = get_norm(norm_type)
-    res = x + tf.nn.dropout(y, 1.0 - residual_dropout)
-    if norm_type == "layer":
-      return norm_fn(res, filters=filters, epsilon=epsilon, name=norm_type)
-    else:
-      return norm_fn(res, name=norm_type)
+  with tf.variable_scope(name):
+    for c in sequence:
+      if c == "a":
+        x += previous_value
+      elif c == "n":
+        x = apply_norm(x, norm_type, depth, epsilon)
+      else:
+        assert c == "d", ("Unknown sequence step %s" % c)
+        x = tf.nn.dropout(x, 1.0 - dropout_rate)
+    return x
+
+
+def layer_preprocess(layer_input, hparams):
+  """Apply layer preprocessing.
+
+  See layer_prepostprocess() for details.
+
+  A hyperparameters object is passed for convenience. The hyperparameters
+  that may be used are:
+
+    layer_preprocess_sequence
+    layer_prepostprocess_dropout
+    norm_type
+    hidden_size
+    norm_epsilon
+
+  Args:
+    layer_input: a Tensor
+    hparams: a hyperparameters object.
+
+  Returns:
+    a Tensor
+  """
+  assert "a" not in hparams.layer_preprocess_sequence, (
+      "No residual connections allowed in hparams.layer_preprocess_sequence")
+  return layer_prepostprocess(
+      None, layer_input,
+      sequence=hparams.layer_preprocess_sequence,
+      dropout_rate=hparams.layer_prepostprocess_dropout,
+      norm_type=hparams.norm_type,
+      depth=hparams.hidden_size,
+      epsilon=hparams.norm_epsilon,
+      name="layer_prepostprocess")
+
+
+def layer_postprocess(layer_input, layer_output, hparams):
+  """Apply layer postprocessing.
+
+  See layer_prepostprocess() for details.
+
+  A hyperparameters object is passed for convenience. The hyperparameters
+  that may be used are:
+
+    layer_postprocess_sequence
+    layer_prepostprocess_dropout
+    norm_type
+    hidden_size
+    norm_epsilon
+
+  Args:
+    layer_input: a Tensor
+    layer_output: a Tensor
+    hparams: a hyperparameters object.
+
+  Returns:
+    a Tensor
+  """
+  return layer_prepostprocess(
+      layer_input, layer_output,
+      sequence=hparams.layer_postprocess_sequence,
+      dropout_rate=hparams.layer_prepostprocess_dropout,
+      norm_type=hparams.norm_type,
+      depth=hparams.hidden_size,
+      epsilon=hparams.norm_epsilon,
+      name="layer_postprocess")
 
 
 def conv_block_internal(conv_fn,
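
To make the sequence semantics above concrete, the following is a NumPy-only sketch of the same control flow. It is an illustration, not the library code: the real layer_prepostprocess operates on Tensors, uses tf.nn.dropout, and dispatches normalization through apply_norm().

import numpy as np

def toy_prepostprocess(previous_value, x, sequence, dropout_rate, rng):
  """NumPy illustration of the 'd' / 'n' / 'a' sequence interpreter."""
  for c in sequence:
    if c == "a":        # residual: add the saved layer input
      x = x + previous_value
    elif c == "n":      # normalization (layer-norm style stand-in)
      mean = x.mean(axis=-1, keepdims=True)
      var = x.var(axis=-1, keepdims=True)
      x = (x - mean) / np.sqrt(var + 1e-6)
    else:               # inverted dropout
      assert c == "d", "Unknown sequence step %s" % c
      keep = rng.binomial(1, 1.0 - dropout_rate, size=x.shape)
      x = x * keep / (1.0 - dropout_rate)
  return x

# sequence="dan" (the published default) computes
#   normalize(previous_value + dropout(x)),
# while sequence="dna" computes previous_value + normalize(dropout(x)),
# matching the docstring example above.
rng = np.random.RandomState(0)
prev = np.random.RandomState(1).randn(2, 4)
out = toy_prepostprocess(prev, np.ones((2, 4)), "dan", 0.1, rng)
print(out.shape)  # (2, 4)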

tensor2tensor/layers/common_layers_test.py

Lines changed: 12 additions & 43 deletions
@@ -303,74 +303,43 @@ def testDeconvStride2MultiStep(self):
       actual = session.run(a)
     self.assertEqual(actual.shape, (5, 32, 1, 16))
 
-  def testGetNormLayerFn(self):
-    norm_type = "layer"
+  def testApplyNormLayer(self):
     with self.test_session() as session:
-      a = common_layers.get_norm(norm_type)
       x1 = np.random.rand(5, 2, 1, 11)
-      x2 = a(tf.constant(x1, dtype=tf.float32), name="layer", filters=11)
+      x2 = common_layers.apply_norm(
+          tf.constant(x1, dtype=tf.float32), "layer", depth=11, epsilon=1e-6)
       session.run(tf.global_variables_initializer())
       actual = session.run(x2)
     self.assertEqual(actual.shape, (5, 2, 1, 11))
 
-  def testGetNormNoamFn(self):
-    norm_type = "noam"
+  def testApplyNormNoam(self):
     with self.test_session() as session:
-      a = common_layers.get_norm(norm_type)
       x1 = np.random.rand(5, 2, 1, 11)
-      x2 = a(tf.constant(x1, dtype=tf.float32), name="noam")
+      x2 = common_layers.apply_norm(
+          tf.constant(x1, dtype=tf.float32), "noam", depth=11, epsilon=1e-6)
       session.run(tf.global_variables_initializer())
       actual = session.run(x2)
     self.assertEqual(actual.shape, (5, 2, 1, 11))
 
-  def testGetNormBatchFn(self):
-    norm_type = "batch"
+  def testApplyNormBatch(self):
     with self.test_session() as session:
-      a = common_layers.get_norm(norm_type)
       x1 = np.random.rand(5, 2, 1, 11)
-      x2 = a(tf.constant(x1, dtype=tf.float32), name="batch")
+      x2 = common_layers.apply_norm(
+          tf.constant(x1, dtype=tf.float32), "batch", depth=11, epsilon=1e-6)
       session.run(tf.global_variables_initializer())
       actual = session.run(x2)
     self.assertEqual(actual.shape, (5, 2, 1, 11))
 
-  def testGetNormNoneFn(self):
-    norm_type = "none"
+  def testApplyNormNone(self):
     with self.test_session() as session:
-      a = common_layers.get_norm(norm_type)
       x1 = np.random.rand(5, 2, 1, 11)
-      x2 = a(tf.constant(x1, dtype=tf.float32), name="none")
+      x2 = common_layers.apply_norm(
+          tf.constant(x1, dtype=tf.float32), "none", depth=11, epsilon=1e-6)
       session.run(tf.global_variables_initializer())
       actual = session.run(x2)
     self.assertEqual(actual.shape, (5, 2, 1, 11))
     self.assertAllClose(actual, x1, atol=1e-03)
 
-  def testResidualFn(self):
-    norm_type = "batch"
-    with self.test_session() as session:
-      x1 = np.random.rand(5, 2, 1, 11)
-      x2 = np.random.rand(5, 2, 1, 11)
-      x3 = common_layers.residual_fn(
-          tf.constant(x1, dtype=tf.float32),
-          tf.constant(x2, dtype=tf.float32), norm_type, 0.1)
-      session.run(tf.global_variables_initializer())
-      actual = session.run(x3)
-    self.assertEqual(actual.shape, (5, 2, 1, 11))
-
-  def testResidualFnWithLayerNorm(self):
-    norm_type = "layer"
-    with self.test_session() as session:
-      x1 = np.random.rand(5, 2, 1, 11)
-      x2 = np.random.rand(5, 2, 1, 11)
-      x3 = common_layers.residual_fn(
-          tf.constant(x1, dtype=tf.float32),
-          tf.constant(x2, dtype=tf.float32),
-          norm_type,
-          0.1,
-          epsilon=0.1)
-      session.run(tf.global_variables_initializer())
-      actual = session.run(x3)
-    self.assertEqual(actual.shape, (5, 2, 1, 11))
-
   def testGlobalPool1d(self):
     x1 = np.random.rand(5, 4, 11)
     no_mask = np.ones((5, 4))

tensor2tensor/models/attention_lm.py

Lines changed: 24 additions & 30 deletions
@@ -51,13 +51,10 @@ def model_fn_body(self, features):
     (decoder_input, decoder_self_attention_bias) = attention_lm_prepare_decoder(
         targets, hparams)
 
-    def residual_fn(x, y):
-      return common_layers.layer_norm(x + tf.nn.dropout(
-          y, 1.0 - hparams.residual_dropout))
-
-    decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
-    decoder_output = attention_lm_decoder(decoder_input, residual_fn,
-                                          decoder_self_attention_bias, hparams)
+    decoder_input = tf.nn.dropout(
+        decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
+    decoder_output = attention_lm_decoder(
+        decoder_input, decoder_self_attention_bias, hparams)
     decoder_output = tf.expand_dims(decoder_output, 2)
 
     return decoder_output

@@ -84,15 +81,13 @@ def attention_lm_prepare_decoder(targets, hparams):
 
 
 def attention_lm_decoder(decoder_input,
-                         residual_fn,
                          decoder_self_attention_bias,
                          hparams,
                          name="decoder"):
   """A stack of attention_lm layers.
 
   Args:
     decoder_input: a Tensor
-    residual_fn: a function from (layer_input, layer_output) -> combined_output
     decoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
     hparams: hyperparameters for model

@@ -105,25 +100,25 @@ def attention_lm_decoder(decoder_input,
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
-        x = residual_fn(
-            x,
-            common_attention.multihead_attention(
-                x,
-                None,
-                decoder_self_attention_bias,
-                hparams.attention_key_channels or hparams.hidden_size,
-                hparams.attention_value_channels or hparams.hidden_size,
-                hparams.hidden_size,
-                hparams.num_heads,
-                hparams.attention_dropout,
-                name="decoder_self_attention"))
-        x = residual_fn(x,
-                        common_layers.conv_hidden_relu(
-                            x,
-                            hparams.filter_size,
-                            hparams.hidden_size,
-                            dropout=hparams.relu_dropout))
-  return x
+        with tf.variable_scope("self_attention"):
+          y = common_attention.multihead_attention(
+              common_layers.layer_preprocess(x, hparams),
+              None,
+              decoder_self_attention_bias,
+              hparams.attention_key_channels or hparams.hidden_size,
+              hparams.attention_value_channels or hparams.hidden_size,
+              hparams.hidden_size,
+              hparams.num_heads,
+              hparams.attention_dropout)
+          x = common_layers.layer_postprocess(x, y, hparams)
+        with tf.variable_scope("ffn"):
+          y = common_layers.conv_hidden_relu(
+              common_layers.layer_preprocess(x, hparams),
+              hparams.filter_size,
+              hparams.hidden_size,
+              dropout=hparams.relu_dropout)
+          x = common_layers.layer_postprocess(x, y, hparams)
+  return common_layers.layer_preprocess(x, hparams)
 
 
 @registry.register_hparams

@@ -158,7 +153,6 @@ def attention_lm_base():
   # when not in training mode.
   hparams.add_hparam("attention_dropout", 0.0)
   hparams.add_hparam("relu_dropout", 0.0)
-  hparams.add_hparam("residual_dropout", 0.1)
   hparams.add_hparam("pos", "timing")  # timing, none
   return hparams
 

@@ -178,5 +172,5 @@ def attention_lm_small():
   hparams.num_hidden_layers = 4
   hparams.hidden_size = 512
   hparams.filter_size = 2048
-  hparams.residual_dropout = 0.5
+  hparams.layer_prepostprocess_dropout = 0.5
   return hparams
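
The restructured decoder applies the same sandwich to every sublayer: y = sublayer(layer_preprocess(x, hparams)), then x = layer_postprocess(x, y, hparams), with one final layer_preprocess on the stack output. The sketch below mirrors that wiring with NumPy stand-ins (dropout omitted); the helper names are hypothetical, and only the call pattern follows the diff.

import numpy as np

def layer_norm_np(x, eps=1e-6):
  # NumPy stand-in for common_layers.apply_norm with norm_type="layer".
  return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def decoder_stack(x, sublayers, prenorm=False):
  # prenorm=False mimics ("", "dan"): identity preprocess, normalize after the residual.
  # prenorm=True mimics ("n", "da"): normalize the sublayer input, plain residual add.
  pre = layer_norm_np if prenorm else (lambda h: h)
  post = (lambda h, y: h + y) if prenorm else (lambda h, y: layer_norm_np(h + y))
  for sublayer in sublayers:   # e.g. self-attention, then the feed-forward block
    y = sublayer(pre(x))
    x = post(x, y)
  return pre(x)                # final layer_preprocess, as in the diff

x = np.random.RandomState(0).randn(2, 8)
out = decoder_stack(x, [np.tanh, lambda h: 0.5 * h], prenorm=True)
print(out.shape)  # (2, 8)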
