# coding=utf-8
# Copyright 2021 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evolved Transformer model.
This implements the model described in arxiv.org/abs/1901.11117 .
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import inplace_ops
# pylint: enable=g-direct-tensorflow-import
_CONV_BRANCHES_NAME = "conv_branches"
_CONV_BRANCHES_FIRST_LAYER_NAME = _CONV_BRANCHES_NAME + "_first"
_CONV_BRANCHES_SECOND_LAYER_NAME = _CONV_BRANCHES_NAME + "_second"
_FIRST_ATTEND_TO_ENCODER_NAME = "first_attend_to_encoder"
_SECOND_ATTEND_TO_ENCODER_NAME = "second_attend_to_encoder"
_SIXTEEN_HEAD_ATTENTION_NAME = "16_head_self_attention"
_VANILLA_ATTENTION_NAME = "self_attention"
_DECODER_LEFT_CONV_PADDING = 10
_DECODER_RIGHT_CONV_PADDING = 6
_DECODER_FINAL_CONV_PADDING = 6
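# Note: the decoder convolutions below are made causal by manually
# left-padding their inputs and using padding="VALID". Each constant above
# is kernel_size - 1 for the corresponding convolution: 10 for the 11x1
# left branch, 6 for the 7x1 right branch, and 6 for the final 7x1 layer.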
def _capped_double_heads(num_heads, cap=16):
"""Calculate the number of heads for the attention layers with more heads.
The number of heads is doubled from num_heads, capped at |cap| heads, and
never reduced below the original num_heads.
Args:
num_heads: the num_heads hparam for the model.
cap: the maximum number of heads |num_heads| will be doubled to.
Returns:
The number of heads for the attention layers that have more heads.
"""
return max(min(num_heads * 2, cap), num_heads)
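# Illustrative values with the default cap of 16:
#   _capped_double_heads(4) == 8
#   _capped_double_heads(8) == 16
#   _capped_double_heads(16) == 16  (already at the cap)
#   _capped_double_heads(32) == 32  (never reduced below num_heads)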
@registry.register_model
class EvolvedTransformer(transformer.Transformer):
"""The Evolved Transformer from arxiv.org/abs/1901.11117 ."""
def __init__(self, *args, **kwargs):
super(EvolvedTransformer, self).__init__(*args, **kwargs)
self._encoder_function = evolved_transformer_encoder
self._decoder_function = evolved_transformer_decoder
self._init_cache_fn = init_evolved_transformer_cache
# -1 means train all weights.
if self.hparams.get("num_trainable_top_decoder_layers", -1) < 0:
t2t_model.log_info(
"num_trainable_top_decoder_layers is negative so training all weights."
)
elif self.hparams.shared_embedding_and_softmax_weights:
t2t_model.log_info(
"Setting hparams.shared_embedding_and_softmax_weights to False, "
"because hparam.num_trainable_top_decoder_layers is being used.")
# When hparam.num_trainable_top_decoder_layers is set to N >= 0 we will
# freeze (not train) every variable except the N top decoder layers and
# the (pre-)softmax matrix. For any N >= 0 we will freeze the encoder and
# input/target embeddings. This also means we will not share the
# (pre-)softmax matrix with the input/target embeddings; otherwise they
# would be trained as well.
self.hparams.shared_embedding_and_softmax_weights = False
# If hparams.shared_embedding_and_softmax_weights was previously True,
# then input and target embeddings were being shared.
# To make sure the embeddings continue to be shared, we need to set
# hparams.shared_embedding to True.
self.hparams.shared_embedding = True
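# Sketch of the freezing behavior described above (see
# evolved_transformer_decoder): when num_trainable_top_decoder_layers is
# N >= 0, tf.stop_gradient is applied to the encoder output and to the
# hidden state entering the bottom (num_layers - N) decoder layers, so only
# the top N decoder layers and the softmax projection receive gradients.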
def evolved_transformer_encoder(encoder_input,
encoder_self_attention_bias,
hparams,
name="encoder",
nonpadding=None,
save_weights_to=None,
make_image_summary=True,
losses=None,
attn_bias_for_padding=None):
"""Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details.
Note: Pad remover is not supported.
Args:
encoder_input: a Tensor.
encoder_self_attention_bias: bias Tensor for self-attention (see
common_attention.attention_bias()).
hparams: hyperparameters for model.
name: a string.
nonpadding: optional Tensor with shape [batch_size, encoder_length]
indicating what positions are not padding. This must either be passed in,
which we do for "packed" datasets, or inferred from
encoder_self_attention_bias. The knowledge about padding is used for
pad_remover (efficiency) and to mask out padding in convolutional layers.
save_weights_to: an optional dictionary to capture attention weights for
visualization; the weights tensor will be appended there under a string
key created from the variable scope (including name).
make_image_summary: Whether to make an attention image summary.
losses: Not used.
attn_bias_for_padding: Padded attention bias in case a unidirectional
encoder is being used where future attention is masked.
Returns:
Tensor encoder output.
"""
del losses
hidden_state = encoder_input
attention_dropout_broadcast_dims = (
common_layers.comma_separated_string_to_integer_list(
getattr(hparams, "attention_dropout_broadcast_dims", "")))
with tf.variable_scope(name):
if nonpadding is not None:
padding = 1.0 - nonpadding
else:
attention_bias = encoder_self_attention_bias
if attn_bias_for_padding is not None:
attention_bias = attn_bias_for_padding
# Only bfloat16 and float32 supported.
float_type = hparams.get("activation_dtype", "float32")
if float_type == "bfloat16":
cast_fn = tf.to_bfloat16
else:
assert float_type == "float32"
cast_fn = tf.to_float
padding = common_attention.attention_bias_to_padding(
attention_bias, cast_fn)
nonpadding = 1.0 - padding
for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
with tf.variable_scope("gated_linear_unit"):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
values = common_layers.layers().Dense(
hparams.hidden_size)(hidden_state)
gates = common_layers.layers().Dense(
hparams.hidden_size, activation=tf.nn.sigmoid)(hidden_state)
hidden_state = values * gates
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
with tf.variable_scope("conv_branches"):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
# Mask padding from conv layers.
mask = tf.tile(
tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
hidden_state *= mask
left_output_dim = int(hparams.hidden_size * 4)
left_state = common_layers.layers().Dense(
left_output_dim, activation=tf.nn.relu)(hidden_state)
left_state = tf.nn.dropout(left_state,
1 - hparams.layer_prepostprocess_dropout)
right_output_dim = int(hparams.hidden_size / 2)
right_state = common_layers.layers().Conv1D(
right_output_dim,
3,
padding="SAME",
name="standard_conv_3x1",
activation=tf.nn.relu)(hidden_state)
right_state = tf.nn.dropout(right_state,
1 - hparams.layer_prepostprocess_dropout)
right_state = tf.pad(
right_state,
[[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
constant_values=0)
hidden_state = left_state + right_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
# Mask padding from conv layer.
mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, left_output_dim])
hidden_state *= mask
separable_conv_9x1 = common_layers.layers().SeparableConv1D(
right_output_dim, 9, padding="SAME", name="separable_conv_9x1")
hidden_state = separable_conv_9x1(hidden_state)
hidden_state = tf.pad(
hidden_state,
[[0, 0], [0, 0], [0, hparams.hidden_size - right_output_dim]],
constant_values=0)
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
if hparams.get("et_encoder_self_attention", True):
with tf.variable_scope("self_attention"):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
hidden_state = common_attention.multihead_attention(
hidden_state,
None,
encoder_self_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
attention_type=hparams.self_attention_type,
max_relative_position=hparams.max_relative_position,
heads_share_relative_embedding=(
hparams.heads_share_relative_embedding),
add_relative_to_values=hparams.add_relative_to_values,
save_weights_to=save_weights_to,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
max_length=hparams.get("max_length"),
vars_3d=hparams.get("attention_variables_3d"),
activation_dtype=hparams.get("activation_dtype", "float32"),
weight_dtype=hparams.get("weight_dtype", "float32"))
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
with tf.variable_scope("dense_layers"):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
hidden_state = common_layers.layers().Dense(
int(hparams.hidden_size * 4), activation=tf.nn.relu)(hidden_state)
hidden_state = tf.nn.dropout(hidden_state,
1 - hparams.layer_prepostprocess_dropout)
hidden_state = common_layers.layers().Dense(
hparams.hidden_size)(hidden_state)
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
# If normalization is done in layer_preprocess, then it should also be done
# on the output, since the output can grow very large, being the sum of
# a whole stack of unnormalized layer outputs.
return common_layers.layer_preprocess(hidden_state, hparams)
def evolved_transformer_decoder(decoder_input,
encoder_output,
decoder_self_attention_bias,
encoder_decoder_attention_bias,
hparams,
cache=None,
decode_loop_step=None,
name="decoder",
nonpadding=None,
save_weights_to=None,
make_image_summary=True,
losses=None):
"""Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details.
Args:
decoder_input: a Tensor.
encoder_output: a Tensor.
decoder_self_attention_bias: bias Tensor for self-attention (see
common_attention.attention_bias()).
encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
(see common_attention.attention_bias()).
hparams: hyperparameters for model.
cache: dict, containing tensors which are the results of previous
layers, used for fast decoding.
decode_loop_step: An integer, step number of the decoding loop. Only used
for inference on TPU.
name: a string.
nonpadding: optional Tensor with shape [batch_size, encoder_length]
indicating what positions are not padding. This is used to mask out
padding in convolutional layers. We generally only need this mask for
"packed" datasets, because for ordinary datasets, no padding is ever
followed by nonpadding.
save_weights_to: an optional dictionary to capture attention weights for
visualization; the weights tensor will be appended there under a string
key created from the variable scope (including name).
make_image_summary: Whether to make an attention image summary.
losses: Not supported.
Returns:
Decoder output tensor.
"""
del losses
num_trainable_top_decoder_layers = hparams.get(
"num_trainable_top_decoder_layers", -1) # -1 means train all weights.
if num_trainable_top_decoder_layers >= 0:
encoder_output = tf.stop_gradient(encoder_output)
attention_dropout_broadcast_dims = (
common_layers.comma_separated_string_to_integer_list(
getattr(hparams, "attention_dropout_broadcast_dims", "")))
with tf.variable_scope(name):
hidden_state = decoder_input
num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
for layer in range(num_layers):
if num_trainable_top_decoder_layers == num_layers - layer:
hidden_state = tf.stop_gradient(hidden_state)
layer_name = "layer_%d" % layer
layer_cache = cache[layer_name] if cache is not None else None
with tf.variable_scope(layer_name):
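# Decoder layer layout: a widened self-attention branch (heads doubled,
# capped at 16) summed with a first encoder-decoder attention branch when
# an encoder output is present, followed by a two-layer causal convolution
# block, a standard self-attention layer, a second encoder-decoder
# attention layer, and a dense block with a swish hidden activation.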
with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
attention_cache = layer_cache[
_SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None
left_state = common_attention.multihead_attention(
hidden_state,
None,
decoder_self_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
_capped_double_heads(hparams.num_heads),
hparams.attention_dropout,
attention_type=hparams.self_attention_type,
max_relative_position=hparams.max_relative_position,
heads_share_relative_embedding=(
hparams.heads_share_relative_embedding),
add_relative_to_values=hparams.add_relative_to_values,
save_weights_to=save_weights_to,
cache=attention_cache,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
max_length=hparams.get("max_length"),
decode_loop_step=decode_loop_step,
vars_3d=hparams.get("attention_variables_3d"),
activation_dtype=hparams.get("activation_dtype", "float32"),
weight_dtype=hparams.get("weight_dtype", "float32"))
if encoder_output is not None:
with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME):
attention_cache = (
layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME]
if layer_cache is not None else None)
right_state = common_attention.multihead_attention(
hidden_state,
encoder_output,
encoder_decoder_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
max_relative_position=hparams.max_relative_position,
heads_share_relative_embedding=(
hparams.heads_share_relative_embedding),
add_relative_to_values=hparams.add_relative_to_values,
save_weights_to=save_weights_to,
cache=attention_cache,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
max_length=hparams.get("max_length"),
vars_3d=hparams.get("attention_variables_3d"),
activation_dtype=hparams.get("activation_dtype", "float32"),
weight_dtype=hparams.get("weight_dtype", "float32"))
left_state = tf.nn.dropout(left_state,
1 - hparams.layer_prepostprocess_dropout)
right_state = tf.nn.dropout(
right_state, 1 - hparams.layer_prepostprocess_dropout)
hidden_state = residual_state + left_state + right_state
else:
hidden_state = common_layers.layer_postprocess(
residual_state, left_state, hparams)
with tf.variable_scope(_CONV_BRANCHES_NAME):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
if nonpadding is not None:
# Mask padding from conv layers.
mask = tf.tile(
tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
hidden_state *= mask
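# Incremental decoding keeps a per-layer cache of recent positions so the
# causal convolutions below see their full receptive field at every step:
# without decode_loop_step the cache is a rolling window of the last
# _DECODER_LEFT_CONV_PADDING + 1 positions, while on TPU the full-length
# cache is updated in place at the current step and then sliced to the
# windows needed by the 11x1 and 7x1 branches.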
if layer_cache:
if decode_loop_step is None:
hidden_state = layer_cache[
_CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat(
[
layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME],
hidden_state
],
axis=1)[:, -1 * _DECODER_LEFT_CONV_PADDING - 1:, :]
left_state = hidden_state
right_state = hidden_state[:, _DECODER_LEFT_CONV_PADDING -
_DECODER_RIGHT_CONV_PADDING:, :]
else:
# Inplace update is required for inference on TPU.
# Inplace_ops only supports inplace_update on the first dimension.
tmp = tf.transpose(
layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], perm=[1, 0, 2])
tmp = tf.expand_dims(tmp, axis=1)
tmp = inplace_ops.alias_inplace_update(
tmp,
decode_loop_step * tf.shape(hidden_state)[1] +
_DECODER_LEFT_CONV_PADDING,
tf.transpose(hidden_state, perm=[1, 0, 2]))
tmp = tf.squeeze(tmp, axis=1)
hidden_state = layer_cache[
_CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose(
tmp, perm=[1, 0, 2])
batch_size = hidden_state.shape.as_list()[0]
left_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
batch_size, _DECODER_LEFT_CONV_PADDING + 1,
hparams.hidden_size
])
right_state = tf.slice(hidden_state, [
0, decode_loop_step + _DECODER_LEFT_CONV_PADDING -
_DECODER_RIGHT_CONV_PADDING, 0
], [
batch_size, _DECODER_RIGHT_CONV_PADDING + 1,
hparams.hidden_size
])
else: # No caching.
left_state = tf.pad(
hidden_state,
paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0], [0, 0]])
right_state = tf.pad(
hidden_state,
paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0], [0, 0]])
left_output_dim = int(hparams.hidden_size * 2)
separable_conv_11x1 = tf.layers.SeparableConv1D(
left_output_dim,
11,
padding="VALID",
name="separable_conv11x1",
activation=tf.nn.relu)
left_state = separable_conv_11x1.apply(left_state)
left_state = tf.nn.dropout(left_state,
1 - hparams.layer_prepostprocess_dropout)
right_output_dim = int(hparams.hidden_size / 2)
separable_conv_7x1_1 = tf.layers.SeparableConv1D(
right_output_dim, 7, padding="VALID", name="separable_conv_7x1_1")
right_state = separable_conv_7x1_1.apply(right_state)
right_state = tf.nn.dropout(right_state,
1 - hparams.layer_prepostprocess_dropout)
right_state = tf.pad(
right_state,
[[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
constant_values=0)
hidden_state = left_state + right_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
if nonpadding is not None:
# Mask padding from conv layers.
mask = tf.tile(
tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2])
hidden_state *= mask
if layer_cache:
if decode_loop_step is None:
hidden_state = layer_cache[
_CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat(
[
layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME],
hidden_state
],
axis=1)[:, -1 * _DECODER_FINAL_CONV_PADDING - 1:, :]
else:
# Inplace update is required for inference on TPU.
# Inplace_ops only supports inplace_update on the first dimension.
tmp = tf.transpose(
layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], perm=[1, 0, 2])
tmp = tf.expand_dims(tmp, axis=1)
tmp = inplace_ops.alias_inplace_update(
tmp, (decode_loop_step + _DECODER_FINAL_CONV_PADDING) *
tf.shape(hidden_state)[1],
tf.transpose(hidden_state, perm=[1, 0, 2]))
tmp = tf.squeeze(tmp, axis=1)
hidden_state = layer_cache[
_CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose(
tmp, perm=[1, 0, 2])
batch_size = hidden_state.shape.as_list()[0]
hidden_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
batch_size, _DECODER_FINAL_CONV_PADDING + 1,
hparams.hidden_size * 2
])
else:
hidden_state = tf.pad(
hidden_state,
paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0], [0, 0]])
separable_conv_7x1_2 = tf.layers.SeparableConv1D(
hparams.hidden_size,
7,
padding="VALID",
name="separable_conv_7x1_2")
hidden_state = separable_conv_7x1_2.apply(hidden_state)
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
with tf.variable_scope(_VANILLA_ATTENTION_NAME):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
attention_cache = layer_cache[
_VANILLA_ATTENTION_NAME] if layer_cache is not None else None
hidden_state = common_attention.multihead_attention(
hidden_state,
None,
decoder_self_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
attention_type=hparams.self_attention_type,
max_relative_position=hparams.max_relative_position,
heads_share_relative_embedding=(
hparams.heads_share_relative_embedding),
add_relative_to_values=hparams.add_relative_to_values,
save_weights_to=save_weights_to,
cache=attention_cache,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
max_length=hparams.get("max_length"),
decode_loop_step=decode_loop_step,
vars_3d=hparams.get("attention_variables_3d"),
activation_dtype=hparams.get("activation_dtype", "float32"),
weight_dtype=hparams.get("weight_dtype", "float32"))
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
if encoder_output is not None:
with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
attention_cache = (
layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME]
if layer_cache is not None else None)
hidden_state = common_attention.multihead_attention(
hidden_state,
encoder_output,
encoder_decoder_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
max_relative_position=hparams.max_relative_position,
heads_share_relative_embedding=(
hparams.heads_share_relative_embedding),
add_relative_to_values=hparams.add_relative_to_values,
save_weights_to=save_weights_to,
cache=attention_cache,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
max_length=hparams.get("max_length"),
vars_3d=hparams.get("attention_variables_3d"),
activation_dtype=hparams.get("activation_dtype", "float32"),
weight_dtype=hparams.get("weight_dtype", "float32"))
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
with tf.variable_scope("dense_layers"):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
hidden_state = tf.layers.dense(
hidden_state,
int(hparams.hidden_size * 4),
activation=tf.nn.swish)
hidden_state = tf.nn.dropout(hidden_state,
1 - hparams.layer_prepostprocess_dropout)
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
decoder_output = common_layers.layer_preprocess(hidden_state, hparams)
if num_trainable_top_decoder_layers == 0:
decoder_output = tf.stop_gradient(decoder_output)
return decoder_output
def _add_attend_to_encoder_cache(cache, attention_name, hparams, num_layers,
key_channels, value_channels,
vars_3d_num_heads, scope_prefix,
encoder_output):
"""Add attend-to-encoder layers to cache."""
for layer in range(num_layers):
layer_name = "layer_%d" % layer
with tf.variable_scope("%sdecoder/%s/%s/multihead_attention" %
(scope_prefix, layer_name, attention_name)):
k_encdec = common_attention.compute_attention_component(
encoder_output,
key_channels,
name="k",
vars_3d_num_heads=vars_3d_num_heads)
k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
v_encdec = common_attention.compute_attention_component(
encoder_output,
value_channels,
name="v",
vars_3d_num_heads=vars_3d_num_heads)
v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
cache[layer_name][attention_name] = {
"k_encdec": k_encdec,
"v_encdec": v_encdec
}
return cache
def init_evolved_transformer_cache(cache, hparams, batch_size,
attention_init_length, encoder_output,
encoder_decoder_attention_bias,
scope_prefix):
"""Create the initial cache for Evolved Transformer fast decoding."""
key_channels = hparams.attention_key_channels or hparams.hidden_size
value_channels = hparams.attention_value_channels or hparams.hidden_size
num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
vars_3d_num_heads = (
hparams.num_heads if hparams.get("attention_variables_3d") else 0)
# Add self-attentions.
if cache is None:
cache = {}
cache.update({
"layer_%d" % layer: { # pylint: disable=g-complex-comprehension
_SIXTEEN_HEAD_ATTENTION_NAME: {
"k":
common_attention.split_heads(
tf.zeros(
[batch_size, attention_init_length, key_channels]),
_capped_double_heads(hparams.num_heads)),
"v":
common_attention.split_heads(
tf.zeros(
[batch_size, attention_init_length, value_channels]),
_capped_double_heads(hparams.num_heads)),
},
_VANILLA_ATTENTION_NAME: {
"k":
common_attention.split_heads(
tf.zeros(
[batch_size, attention_init_length, key_channels]),
hparams.num_heads),
"v":
common_attention.split_heads(
tf.zeros(
[batch_size, attention_init_length, value_channels]),
hparams.num_heads),
}
} for layer in range(num_layers)
})
# Add branched layers. Pad with additional zeros for causal convolution.
for layer in range(num_layers):
cache["layer_%d" % layer][_CONV_BRANCHES_FIRST_LAYER_NAME] = tf.zeros([
batch_size, attention_init_length + _DECODER_LEFT_CONV_PADDING,
hparams.hidden_size
])
cache["layer_%d" % layer][_CONV_BRANCHES_SECOND_LAYER_NAME] = tf.zeros([
batch_size, attention_init_length + _DECODER_FINAL_CONV_PADDING,
hparams.hidden_size * 2
])
# Add encoder embedding attentions.
if encoder_output is not None:
cache = _add_attend_to_encoder_cache(
cache=cache,
attention_name=_FIRST_ATTEND_TO_ENCODER_NAME,
hparams=hparams,
num_layers=num_layers,
key_channels=key_channels,
value_channels=value_channels,
vars_3d_num_heads=vars_3d_num_heads,
scope_prefix=scope_prefix,
encoder_output=encoder_output)
cache = _add_attend_to_encoder_cache(
cache=cache,
attention_name=_SECOND_ATTEND_TO_ENCODER_NAME,
hparams=hparams,
num_layers=num_layers,
key_channels=key_channels,
value_channels=value_channels,
vars_3d_num_heads=vars_3d_num_heads,
scope_prefix=scope_prefix,
encoder_output=encoder_output)
cache["encoder_output"] = encoder_output
cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
return cache
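# Resulting cache layout (illustrative), for each decoder "layer_i":
#   "16_head_self_attention": {"k": ..., "v": ...}
#   "self_attention": {"k": ..., "v": ...}
#   "conv_branches_first": zeros([batch, init_length + 10, hidden_size])
#   "conv_branches_second": zeros([batch, init_length + 6, 2 * hidden_size])
#   "first_attend_to_encoder" / "second_attend_to_encoder":
#       {"k_encdec": ..., "v_encdec": ...}  (only when encoder_output is set)
# plus top-level "encoder_output" and "encoder_decoder_attention_bias".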
# TODO(davidso): Update optimizer, learning rate, and decay to match paper.
def add_evolved_transformer_hparams(hparams):
"""Add Evolved Transformer hparams.
Note: These are for the Adam optimizer, not the Adafactor optimizer used in
the paper.
Args:
hparams: Current hparams.
Returns:
hparams updated with Evolved Transformer values.
"""
# Evolved Transformer "layers" are twice as deep as Transformer, so roughly
# halve the number that we use. These numbers are taken from
# arxiv.org/abs/1901.11117 .
hparams.num_encoder_layers = 3
hparams.num_decoder_layers = 4
# Learning rate and decay scheme that mimics the transformer Adam config,
# but with cosine decay instead of rsqrt.
hparams.learning_rate_constant /= hparams.learning_rate_warmup_steps ** 0.5
hparams.learning_rate_schedule = (
"constant*linear_warmup*single_cycle_cos_decay*rsqrt_hidden_size")
return hparams
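# Illustrative example (10000 is a stand-in value, not a default): if
# hparams.learning_rate_warmup_steps were 10000, learning_rate_constant
# would be divided by sqrt(10000) = 100 above, roughly matching the peak of
# the rsqrt schedule before handing off to single-cycle cosine decay.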
@registry.register_hparams
def evolved_transformer_tiny():
"""Base parameters for Evolved Transformer model."""
hparams = add_evolved_transformer_hparams(transformer.transformer_tiny())
hparams.learning_rate_schedule = (
"constant*single_cycle_cos_decay")
return hparams
@registry.register_hparams
def evolved_transformer_base():
"""Base parameters for Evolved Transformer model."""
return add_evolved_transformer_hparams(transformer.transformer_base())
@registry.register_hparams
def evolved_transformer_big():
"""Big parameters for Evolved Transformer model on WMT."""
return add_evolved_transformer_hparams(transformer.transformer_big())
@registry.register_hparams
def evolved_transformer_deep():
"""Deep parameters for Evolved Transformer model on WMT."""
hparams = add_evolved_transformer_hparams(transformer.transformer_big())
hparams.num_encoder_layers = 9
hparams.num_decoder_layers = 10
hparams.hidden_size = 640
return hparams
@registry.register_hparams
def evolved_transformer_base_tpu():
"""Base parameters for Evolved Transformer model on TPU."""
hparams = add_evolved_transformer_hparams(transformer.transformer_tpu())
hparams.learning_rate_constant = 1 / hparams.learning_rate_warmup_steps ** 0.5
hparams.learning_rate_schedule = (
"constant*single_cycle_cos_decay")
return hparams
@registry.register_hparams
def evolved_transformer_big_tpu():
"""Big parameters for Evolved Transformer model on TPU."""
hparams = add_evolved_transformer_hparams(transformer.transformer_big_tpu())
hparams.learning_rate_constant = 1 / hparams.learning_rate_warmup_steps ** 0.5
hparams.learning_rate_schedule = (
"constant*single_cycle_cos_decay")
return hparams
@registry.register_hparams
def evolved_transformer_tpu_basic():
"""Basic Seq2Seq TPU hyper-parameters."""
hparams = transformer.transformer_big_tpu()
hparams.add_hparam("print_vars", False)
hparams.batch_size = 8192
hparams.max_length = 256
# N < 0 means all weights in the model are trainable.
# N >= 0 means all weights are frozen except N top decoder layers +
# (pre-)softmax matrix (that projects from hidden size to vocab size).
hparams.add_hparam("num_trainable_top_decoder_layers", -1)
return hparams
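# Usage sketch (illustrative): once this module is imported, the model and
# hparams sets are available by name from the tensor2tensor registry, e.g.
#   model_cls = registry.model("evolved_transformer")
#   hparams = registry.hparams("evolved_transformer_base")
# or can be selected with the same names through the standard t2t-trainer
# --model and --hparams_set flags.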