# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Universal Transformer.
The Universal Transformer is based on the popular encoder-decoder architecture.
However, as opposed to a fixed stack of distinct layers (as is usually the case
for most popular neural sequence models), the Universal Transformer is
recurrent "in depth", and repeatedly applies the same series of functions with
the same parameters to all elements of the sequence in parallel, revising their
representations with every step. The encoder and decoder have the same
recurrent structure, but the decoder additionally consumes the final encoder
representations for each position. Like the Transformer, the Universal
Transformer is autoregressive. Trained using teacher-forcing, at generation
time it produces its output one position at a time, with the decoder consuming
the previously produced output positions.
Given an input sequence of length m, we start with a matrix whose rows are the
d-dimensional embeddings of the symbols at each position of the sequence.
The Universal Transformer then iteratively computes representations of the input
at each step by applying the multiheaded dot-product self-attention mechanism,
followed by a recurrent transition function. We also add residual connections
around each of these function blocks and apply dropout and layer normalization.
The recurrent transition function controls how steps communicate with each other
in depth. For instance, the recurrent transition can be a simple identity
function which passes the output of one step as the input to the next step, or
it can be an LSTM (flipped vertically) next to the transformer which controls
how the state of the model changes in depth.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import functools
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import expert_utils
import tensorflow as tf
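# NOTE: a minimal, illustrative sketch (not part of this module's API) of the
# "recurrence in depth" idea from the module docstring: one shared-weight step
# applied a fixed number of times via tf.foldl. The names `_toy_depth_recurrence`
# and `_toy_step` are assumptions for illustration only, and a plain dense layer
# stands in for (self-attention + transition function).
def _toy_depth_recurrence(x, hidden_size, num_steps):
  """Applies one shared-weight step `num_steps` times (illustrative sketch).

  x: a Tensor of shape [batch, length, hidden_size].
  """
  def _toy_step(state, unused_step):
    # The same variables are reused at every step, i.e. the model is recurrent
    # "in depth" rather than a stack of distinct layers.
    with tf.variable_scope("shared_step", reuse=tf.AUTO_REUSE):
      return tf.layers.dense(state, hidden_size, activation=tf.nn.relu)
  return tf.foldl(_toy_step, tf.range(num_steps), initializer=x)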
def universal_transformer_encoder(encoder_input,
encoder_self_attention_bias,
hparams,
name="encoder",
nonpadding=None,
save_weights_to=None,
make_image_summary=True):
"""Universal Transformer encoder function.
Prepares all the arguments and the inputs and passes them to
universal_transformer_layer to encode the encoder_input.
Args:
encoder_input: a Tensor
encoder_self_attention_bias: bias Tensor for self-attention
(see common_attention.attention_bias())
hparams: hyperparameters for model
name: a string
nonpadding: optional Tensor with shape [batch_size, encoder_length]
indicating what positions are not padding. This must either be
passed in, which we do for "packed" datasets, or inferred from
encoder_self_attention_bias. The knowledge about padding is used
for pad_remover (efficiency) and to mask out padding in convolutional
layers.
save_weights_to: an optional dictionary to capture attention weights
for visualization; the weights tensor will be appended there under
a string key created from the variable scope (including name).
make_image_summary: Whether to make an attention image summary.
Returns:
y: a Tensor as the output of the encoder
extra_output: which can be used to pass extra information to the body
"""
x = encoder_input
attention_dropout_broadcast_dims = (
common_layers.comma_separated_string_to_integer_list(
getattr(hparams, "attention_dropout_broadcast_dims", "")))
with tf.variable_scope(name):
if nonpadding is not None:
padding = 1.0 - nonpadding
else:
padding = common_attention.attention_bias_to_padding(
encoder_self_attention_bias)
nonpadding = 1.0 - padding
pad_remover = None
if hparams.use_pad_remover and not common_layers.is_xla_compiled():
pad_remover = expert_utils.PadRemover(padding)
ffn_unit = functools.partial(
transformer_encoder_ffn_unit,
hparams=hparams,
nonpadding_mask=nonpadding,
pad_remover=pad_remover)
attention_unit = functools.partial(
transformer_encoder_attention_unit,
hparams=hparams,
encoder_self_attention_bias=encoder_self_attention_bias,
attention_dropout_broadcast_dims=attention_dropout_broadcast_dims,
save_weights_to=save_weights_to,
make_image_summary=make_image_summary)
x, extra_output = universal_transformer_layer(
x, hparams, ffn_unit, attention_unit, pad_remover=pad_remover)
return common_layers.layer_preprocess(x, hparams), extra_output
def universal_transformer_decoder(decoder_input,
encoder_output,
decoder_self_attention_bias,
encoder_decoder_attention_bias,
hparams,
name="decoder",
nonpadding=None,
save_weights_to=None,
make_image_summary=True):
"""Universal Transformer decoder function.
Prepares all the arguments and the inputs and passes them to
universal_transformer_layer to decode the decoder_input.
Args:
decoder_input: a Tensor
encoder_output: a Tensor
decoder_self_attention_bias: bias Tensor for self-attention
(see common_attention.attention_bias())
encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
(see common_attention.attention_bias())
hparams: hyperparameters for model
name: a string
nonpadding: optional Tensor with shape [batch_size, encoder_length]
indicating what positions are not padding. This is used
to mask out padding in convolutional layers. We generally only
need this mask for "packed" datasets, because for ordinary datasets,
no padding is ever followed by nonpadding.
save_weights_to: an optional dictionary to capture attention weights
for visualization; the weights tensor will be appended there under
a string key created from the variable scope (including name).
make_image_summary: Whether to make an attention image summary.
Returns:
y: the output Tensor
extra_output: which can be used to pass extra information to the body
"""
x = decoder_input
attention_dropout_broadcast_dims = (
common_layers.comma_separated_string_to_integer_list(
getattr(hparams, "attention_dropout_broadcast_dims", "")))
with tf.variable_scope(name):
ffn_unit = functools.partial(
transformer_decoder_ffn_unit,
hparams=hparams,
nonpadding_mask=nonpadding)
attention_unit = functools.partial(
transformer_decoder_attention_unit,
hparams=hparams,
encoder_output=encoder_output,
decoder_self_attention_bias=decoder_self_attention_bias,
encoder_decoder_attention_bias=encoder_decoder_attention_bias,
attention_dropout_broadcast_dims=attention_dropout_broadcast_dims,
save_weights_to=save_weights_to,
make_image_summary=make_image_summary)
x, extra_output = universal_transformer_layer(
x, hparams, ffn_unit, attention_unit)
return common_layers.layer_preprocess(x, hparams), extra_output
def universal_transformer_layer(x,
hparams,
ffn_unit,
attention_unit,
pad_remover=None):
"""Core function applying the universal transformer layer.
Args:
x: input
hparams: model hyper-parameters
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
pad_remover: to mask out padding in convolutional layers (efficiency).
Returns:
the output tensor, extra output (can be memory, ponder time, etc.)
Raises:
ValueError: Unknown recurrence type
"""
def add_vanilla_transformer_layer(x, num_layers, name):
"""Passes the input through num_layers of vanilla transformer layers.
Args:
x: input
num_layers: number of layers
name: string, prefix of layer names
Returns:
output of vanilla_transformer_layer
"""
if hparams.add_position_timing_signal:
# In case of add_position_timing_signal=true, we set hparams.pos=None
# and add position timing signal at the beginning of each step, so for
# the vanilla transformer, we need to add timing signal here.
x = common_attention.add_timing_signal_1d(x)
for layer in range(num_layers):
with tf.variable_scope(name + "layer_%d" % layer):
x = ffn_unit(attention_unit(x))
return x
with tf.variable_scope("universal_transformer_%s" % hparams.recurrence_type):
if (hparams.mix_with_transformer and
"before_ut" in hparams.mix_with_transformer):
x = add_vanilla_transformer_layer(x, hparams.num_mixedin_layers,
"before_ut_")
if hparams.recurrence_type == "act":
output, extra_output = universal_transformer_act(
x, hparams, ffn_unit, attention_unit)
else: # for all the other recurrency types with fixed number of steps
ut_function, initializer = get_ut_layer(x, hparams, ffn_unit,
attention_unit, pad_remover)
output, _, extra_output = tf.foldl(
ut_function, tf.range(hparams.num_rec_steps),
initializer=initializer)
# Right now, this is only possible when the transition function is an lstm
if (hparams.recurrence_type == "lstm" and
hparams.get("use_memory_as_final_state", False)):
output = extra_output
if (hparams.mix_with_transformer and
"after_ut" in hparams.mix_with_transformer):
output = add_vanilla_transformer_layer(output, hparams.num_mixedin_layers,
"after_ut_")
return output, extra_output
def get_ut_layer(x,
hparams,
ffn_unit,
attention_unit,
pad_remover=None):
"""Provides the function that is used in universal transforemr steps.
Args:
x: input
hparams: model hyper-parameters
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
pad_remover: to mask out padding in convolutional layers (efficiency).
Returns:
ut_function and the ut_initializer
Raises:
ValueError: Unknown recurrence type
"""
if hparams.recurrence_type == "basic":
ut_initializer = (x, x, x) # (state, input, memory)
ut_function = functools.partial(
universal_transformer_basic,
hparams=hparams,
ffn_unit=ffn_unit,
attention_unit=attention_unit)
elif hparams.recurrence_type == "highway":
ut_initializer = (x, x, x) # (state, input, memory)
ut_function = functools.partial(
universal_transformer_highway,
hparams=hparams,
ffn_unit=ffn_unit,
attention_unit=attention_unit,
pad_remover=pad_remover)
elif hparams.recurrence_type == "skip":
ut_initializer = (x, x, x) # (state, input, memory)
ut_function = functools.partial(
universal_transformer_skip,
hparams=hparams,
ffn_unit=ffn_unit,
attention_unit=attention_unit,
pad_remover=pad_remover)
elif hparams.recurrence_type == "dwa":
# memory contains the original input + all the states
memory_size = hparams.num_rec_steps + 1
# prepare initializer:
memory_empty = tf.zeros([memory_size] + common_layers.shape_list(x))
# filling the first slot with the original input
memory = fill_memory_slot(memory_empty, x, 0)
ut_initializer = (x, x, memory) # (state, input, memory)
ut_function = functools.partial(
universal_transformer_depthwise_attention,
hparams=hparams,
ffn_unit=ffn_unit,
attention_unit=attention_unit)
elif hparams.recurrence_type == "gru":
ut_initializer = (x, x, x) # (state, input, memory)
ut_function = functools.partial(
universal_transformer_with_gru_as_transition_function,
hparams=hparams,
ffn_unit=ffn_unit,
attention_unit=attention_unit,
pad_remover=pad_remover)
elif hparams.recurrence_type == "lstm":
memory = tf.zeros(common_layers.shape_list(x))
ut_initializer = (x, x, memory) # (state, input, memory)
ut_function = functools.partial(
universal_transformer_with_lstm_as_transition_function,
hparams=hparams,
ffn_unit=ffn_unit,
attention_unit=attention_unit,
pad_remover=pad_remover)
else:
raise ValueError("Unknown recurrence type: %s" % hparams.recurrence_type)
return ut_function, ut_initializer
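# NOTE: a hypothetical minimal example of the calling convention shared by all
# the ut_functions above: tf.foldl in universal_transformer_layer calls them
# with the (state, inputs, memory) accumulator and the current step index, and
# they must return a new 3-tuple of the same structure. `_toy_identity_ut_step`
# is illustrative only (an identity recurrence, not part of the library).
def _toy_identity_ut_step(layer_inputs, step):
  """Identity recurrence with the ut_function calling convention (sketch)."""
  state, inputs, memory = layer_inputs
  del step  # a real step uses it for step/position timing signals
  return state, inputs, memory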
def transformer_encoder_ffn_unit(x,
hparams,
nonpadding_mask=None,
pad_remover=None):
"""Applies a feed-forward function which is parametrised for encoding.
Args:
x: input
hparams: model hyper-parameters
nonpadding_mask: optional Tensor with shape [batch_size, encoder_length]
indicating what positions are not padding. This is used
to mask out padding in convolutional layers. We generally only
need this mask for "packed" datasets, because for ordinary datasets,
no padding is ever followed by nonpadding.
pad_remover: to mask out padding in convolutional layers (efficiency).
Returns:
the output tensor
"""
with tf.variable_scope("ffn"):
if hparams.transformer_ffn_type == "fc":
y = transformer.transformer_ffn_layer(
common_layers.layer_preprocess(x, hparams),
hparams,
pad_remover,
conv_padding="SAME",
nonpadding_mask=nonpadding_mask)
if hparams.transformer_ffn_type == "sepconv":
assert nonpadding_mask is not None, (
"The nonpadding_mask should be provided, otherwise the model uses "
"the leaked padding information to estimate the length!")
y = common_layers.sepconv_relu_sepconv(
common_layers.layer_preprocess(x, hparams),
filter_size=hparams.filter_size,
output_size=hparams.hidden_size,
first_kernel_size=(3, 1),
second_kernel_size=(5, 1),
padding="SAME",
nonpadding_mask=nonpadding_mask,
dropout=hparams.relu_dropout)
x = common_layers.layer_postprocess(x, y, hparams)
return x
def transformer_encoder_attention_unit(x,
hparams,
encoder_self_attention_bias,
attention_dropout_broadcast_dims,
save_weights_to=None,
make_image_summary=True):
"""Applies multihead attention function which is parametrised for encoding.
Args:
x: input
hparams: model hyper-parameters
encoder_self_attention_bias: a bias tensor for use in encoder self-attention
attention_dropout_broadcast_dims: For noise broadcasting in the dropout
layers to save memory during training
save_weights_to: an optional dictionary to capture attention weights for
visualization; the weights tensor will be appended there under a string
key created from the variable scope (including name).
make_image_summary: Whether to make an attention image summary.
Returns:
the output tensor
"""
with tf.variable_scope("self_attention"):
y = common_attention.multihead_attention(
common_layers.layer_preprocess(x, hparams),
None,
encoder_self_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
attention_type=hparams.self_attention_type,
save_weights_to=save_weights_to,
max_relative_position=hparams.max_relative_position,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
hard_attention_k=hparams.hard_attention_k)
x = common_layers.layer_postprocess(x, y, hparams)
return x
def transformer_decoder_ffn_unit(x,
hparams,
nonpadding_mask=None):
"""Applies a feed-forward function which is parametrised for decoding.
Args:
x: input
hparams: model hyper-parameters
nonpadding_mask: optional Tensor with shape [batch_size, encoder_length]
indicating what positions are not padding. This is used
to mask out padding in convolutional layers. We generally only
need this mask for "packed" datasets, because for ordinary datasets,
no padding is ever followed by nonpadding.
Returns:
the output tensor
"""
with tf.variable_scope("ffn"):
if hparams.transformer_ffn_type == "fc":
y = transformer.transformer_ffn_layer(
common_layers.layer_preprocess(x, hparams),
hparams,
conv_padding="LEFT",
nonpadding_mask=nonpadding_mask)
if hparams.transformer_ffn_type == "sepconv":
y = common_layers.sepconv_relu_sepconv(
common_layers.layer_preprocess(x, hparams),
filter_size=hparams.filter_size,
output_size=hparams.hidden_size,
first_kernel_size=(3, 1),
second_kernel_size=(5, 1),
padding="LEFT",
nonpadding_mask=nonpadding_mask,
dropout=hparams.relu_dropout)
x = common_layers.layer_postprocess(x, y, hparams)
return x
def transformer_decoder_attention_unit(x,
hparams,
encoder_output,
decoder_self_attention_bias,
encoder_decoder_attention_bias,
attention_dropout_broadcast_dims,
save_weights_to=None,
make_image_summary=True):
"""Applies multihead attention function which is parametrised for decoding.
Args:
x: input (decoder input)
hparams: model hyper-parameters
encoder_output: Encoder representation. [batch_size, input_length,
hidden_dim]
decoder_self_attention_bias: Bias and mask weights for decoder
self-attention. [batch_size, decoder_length]
encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder
attention. [batch_size, input_length]
attention_dropout_broadcast_dims: For noise broadcasting in the dropout
layers to save memory during training
save_weights_to: an optional dictionary to capture attention weights for
visualization; the weights tensor will be appended there under a string
key created from the variable scope (including name).
make_image_summary: Whether to make an attention image summary.
Returns:
The output tensor
"""
with tf.variable_scope("self_attention"):
y = common_attention.multihead_attention(
common_layers.layer_preprocess(x, hparams),
None,
decoder_self_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
attention_type=hparams.self_attention_type,
save_weights_to=save_weights_to,
max_relative_position=hparams.max_relative_position,
cache=None,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
hard_attention_k=hparams.hard_attention_k)
x = common_layers.layer_postprocess(x, y, hparams)
if encoder_output is not None:
with tf.variable_scope("encdec_attention"):
y = common_attention.multihead_attention(
common_layers.layer_preprocess(x, hparams),
encoder_output,
encoder_decoder_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
save_weights_to=save_weights_to,
make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims,
hard_attention_k=hparams.hard_attention_k)
x = common_layers.layer_postprocess(x, y, hparams)
return x
def universal_transformer_basic(layer_inputs,
step, hparams,
ffn_unit,
attention_unit):
"""Basic Universal Transformer.
This model is similar to the vanilla transformer, but with weights shared
across layers. For some tasks, this simple change yields generalization that
is not achievable by tuning the size or dropout parameters of the vanilla
transformer.
Args:
layer_inputs:
- state: state
step: indicates number of steps taken so far
hparams: model hyper-parameters
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
Returns:
layer_output:
new_state: new state
"""
state, inputs, memory = tf.unstack(layer_inputs, num=None, axis=0,
name="unstack")
new_state = step_preprocess(state, step, hparams)
for i in range(hparams.num_inrecurrence_layers):
with tf.variable_scope("rec_layer_%d" % i):
new_state = ffn_unit(attention_unit(new_state))
return new_state, inputs, memory
def universal_transformer_highway(layer_inputs,
step, hparams,
ffn_unit,
attention_unit,
pad_remover=None):
"""Universal Transformer with highway connection.
It transforms the state using a block containing self-attention and a
transition function and wraps the whole block with a highway connection
(the new state is a combination of the state and the transformed state,
based on carry/transform gates).
Interesting observation:
Controlling the carry/transform gates with the original inputs usually works
better (i.e. hparams.gates_inputs="i").
Args:
layer_inputs:
- state: state
- inputs: the original embedded inputs (= inputs to the first step)
step: indicates number of steps taken so far
hparams: model hyper-parameters.
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
pad_remover: to mask out padding in convolutional layers (efficiency).
Returns:
layer_output:
new_state: new state
inputs: the original embedded inputs (= inputs to the first step)
"""
state, inputs, memory = layer_inputs
new_state = step_preprocess(state, step, hparams)
for i in range(hparams.num_inrecurrence_layers):
with tf.variable_scope("rec_layer_%d" % i):
new_state = ffn_unit(attention_unit(new_state))
transformed_state = new_state
gate_inputs = []
if "s" in hparams.gates_inputs:
gate_inputs.append(state)
if "t" in hparams.gates_inputs:
gate_inputs.append(transformed_state)
if "i" in hparams.gates_inputs:
gate_inputs.append(inputs)
gate_ffn_layer = hparams.gate_ffn_layer
transform_gate = _ffn_layer_multi_inputs(
gate_inputs,
hparams,
ffn_layer_type=gate_ffn_layer,
name="transform",
bias_initializer=tf.constant_initializer(hparams.transform_bias_init),
activation=tf.sigmoid,
pad_remover=pad_remover,
preprocess=True)
if hparams.couple_carry_transform_gates:
carry_gate = tf.subtract(1.0, transform_gate, name="carry")
else:
carry_gate = _ffn_layer_multi_inputs(
gate_inputs,
hparams,
ffn_layer_type=gate_ffn_layer,
name="carry",
bias_initializer=tf.constant_initializer(-hparams.transform_bias_init),
activation=tf.sigmoid,
pad_remover=pad_remover,
preprocess=True)
new_state = state * carry_gate + transformed_state * transform_gate
tf.contrib.summary.scalar("highway_transform_gate_layer",
tf.reduce_mean(transform_gate))
tf.contrib.summary.scalar("highway_carry_gate_layer",
tf.reduce_mean(carry_gate))
return new_state, inputs, memory
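# NOTE: stripped of the multi-input ffn machinery, the highway update above is
# y = state * carry + transformed_state * transform, with carry = 1 - transform
# when the gates are coupled. A minimal standalone sketch follows; the name
# `_toy_highway_combine` and the plain dense gate are illustrative assumptions.
def _toy_highway_combine(state, transformed_state, hidden_size):
  """Combines state and transformed state with coupled gates (illustrative)."""
  transform_gate = tf.layers.dense(
      state, hidden_size, activation=tf.sigmoid, name="toy_transform_gate")
  carry_gate = 1.0 - transform_gate  # coupled carry/transform gates
  return state * carry_gate + transformed_state * transform_gate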
def universal_transformer_skip(layer_inputs,
step,
hparams,
ffn_unit,
attention_unit,
pad_remover=None):
"""Universal Transformer with highway connection.
It transforms the state using attention and ffn and wrap this transformation
with a skip-all connection. (the new state is a combination of the state and
the inputs (original inputs) based on cary/transform gates.)
Observation:
Controlling the cary/transform gate with the original inputs works usually
better (i.e. hparams.gates_inputs="i")
Args:
layer_inputs:
- state: state
- inputs: the original embedded inputs (= inputs to the first step)
step: indicates number of steps taken so far
hparams: model hyper-parameters.
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
pad_remover: to mask out padding in convolutional layers (efficiency).
Returns:
layer_output:
new_state: new state
inputs: the original embedded inputs (= inputs to the first step)
"""
state, inputs, memory = layer_inputs
new_state = step_preprocess(state, step, hparams)
for i in range(hparams.num_inrecurrence_layers):
with tf.variable_scope("rec_layer_%d" % i):
new_state = ffn_unit(attention_unit(new_state))
transformed_state = new_state
inputs.get_shape().assert_is_compatible_with(state.get_shape())
gate_inputs = []
if "s" in hparams.gates_inputs:
gate_inputs.append(state)
if "t" in hparams.gates_inputs:
gate_inputs.append(transformed_state)
if "i" in hparams.gates_inputs:
gate_inputs.append(inputs)
gate_ffn_layer = hparams.gate_ffn_layer
transform_gate = _ffn_layer_multi_inputs(
gate_inputs,
hparams,
ffn_layer_type=gate_ffn_layer,
name="transform",
bias_initializer=tf.constant_initializer(hparams.transform_bias_init),
activation=tf.sigmoid,
pad_remover=pad_remover,
preprocess=True)
if hparams.couple_carry_transform_gates:
carry_gate = tf.subtract(1.0, transform_gate, name="carry")
else:
carry_gate = _ffn_layer_multi_inputs(
gate_inputs,
hparams,
ffn_layer_type=gate_ffn_layer,
name="carry",
bias_initializer=tf.constant_initializer(-hparams.transform_bias_init),
activation=tf.sigmoid,
pad_remover=pad_remover,
preprocess=True)
tf.contrib.summary.scalar("skip_transform_gate_layer",
tf.reduce_mean(transform_gate))
tf.contrib.summary.scalar("skip_carry_gate_layer", tf.reduce_mean(carry_gate))
new_state = inputs * carry_gate + transformed_state * transform_gate
return new_state, inputs, memory
def universal_transformer_depthwise_attention(layer_inputs,
step, hparams,
ffn_unit,
attention_unit):
"""universal_transformer with depth-wise attention.
It uses an attention mechanism-flipped vertically-
over all the states from previous steps to generate the new_state.
Args:
layer_inputs:
- state: state
- memory: contains states from all the previous steps.
step: indicates number of steps taken so far
hparams: model hyper-parameters.
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
Returns:
layer_output:
new_state: new state
memory: contains states from all the previous steps.
"""
_, inputs, memory = layer_inputs
all_states = memory
# add depth signal
if hparams.depth_embedding:
all_states = add_depth_embedding(all_states)
# get the states up to the current step (non-zero part of the memory)
states_so_far = all_states[:step, :, :, :]
states_so_far_weights = tf.nn.softmax(
common_layers.dense(
states_so_far, (hparams.hidden_size if hparams.dwa_elements else 1),
activation=None,
use_bias=True),
axis=-1)
# prepare the state tensor that will be transformed
state_to_be_transformed = tf.reduce_sum(
(states_so_far * states_so_far_weights), axis=0)
new_state = step_preprocess(state_to_be_transformed, step, hparams)
for i in range(hparams.num_inrecurrence_layers):
with tf.variable_scope("rec_layer_%d" % i):
new_state = ffn_unit(attention_unit(new_state))
# add the new state to the memory
memory = fill_memory_slot(memory, new_state, step + 1)
return new_state, inputs, memory
def universal_transformer_with_gru_as_transition_function(
layer_inputs, step, hparams, ffn_unit, attention_unit, pad_remover=None):
"""Universal Transformer which uses a gru as transition function.
It's kind of like having a gru, filliped vertically next to the Universal
Transformer that controls the flow of the information in depth,
over different steps of the Universal Transformer.
Args:
layer_inputs:
- state: state
- inputs: not used here
- memory: not used here
step: indicates number of steps taken so far
hparams: model hyper-parameters.
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
pad_remover: to mask out padding in convolutional layers (efficiency).
Returns:
layer_output:
new_state: new state
inputs: not used
memory: not used
"""
state, unused_inputs, unused_memory = tf.unstack(
layer_inputs, num=None, axis=0, name="unstack")
# state (ut_state): output of the gru in the previous step
# Multi_head_attention:
assert not hparams.add_step_timing_signal # Let gru count for us!
mh_attention_input = step_preprocess(state, step, hparams)
transition_function_input = attention_unit(mh_attention_input)
# Transition Function:
if hparams.add_ffn_unit_to_the_transition_function:
transition_function_input = ffn_unit(transition_function_input)
transition_function_input = common_layers.layer_preprocess(
transition_function_input, hparams)
with tf.variable_scope("gru"):
# gru update gate: z_t = sigmoid(W_z.x_t + U_z.h_{t-1})
transition_function_update_gate = _ffn_layer_multi_inputs(
[transition_function_input, state],
hparams,
name="update",
bias_initializer=tf.constant_initializer(1.0),
activation=tf.sigmoid,
pad_remover=pad_remover)
tf.contrib.summary.scalar("gru_update_gate",
tf.reduce_mean(transition_function_update_gate))
# gru reset gate: r_t = sigmoid(W_r.x_t + U_r.h_{t-1})
transition_function_reset_gate = _ffn_layer_multi_inputs(
[transition_function_input, state],
hparams,
name="reset",
bias_initializer=tf.constant_initializer(1.0),
activation=tf.sigmoid,
pad_remover=pad_remover)
tf.contrib.summary.scalar("gru_reset_gate",
tf.reduce_mean(transition_function_reset_gate))
reset_state = transition_function_reset_gate * state
# gru candidate activation: h'_t = tanh(W.x_t + U.(r_t * h_{t-1}))
transition_function_candidate = _ffn_layer_multi_inputs(
[transition_function_input, reset_state],
hparams,
name="candidate",
bias_initializer=tf.zeros_initializer(),
activation=tf.tanh,
pad_remover=pad_remover)
transition_function_output = (
(1 - transition_function_update_gate) * transition_function_input +
transition_function_update_gate * transition_function_candidate)
transition_function_output = common_layers.layer_preprocess(
transition_function_output, hparams)
return transition_function_output, unused_inputs, unused_memory
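# NOTE: with the pad removal and pre/post-processing stripped away, the GRU
# transition above reduces to the standard gate equations applied to the
# attention output x and the previous depth-state h. A minimal sketch under
# assumed names (`_toy_gru_transition`) follows; note that, as in the code
# above, the update gate interpolates between x and the candidate (not between
# h and the candidate as in a textbook GRU).
def _toy_gru_transition(x, h, hidden_size):
  """One GRU-style depth transition (illustrative sketch)."""
  xh = tf.concat([x, h], axis=-1)
  z = tf.layers.dense(xh, hidden_size, activation=tf.sigmoid, name="toy_update")
  r = tf.layers.dense(xh, hidden_size, activation=tf.sigmoid, name="toy_reset")
  candidate = tf.layers.dense(
      tf.concat([x, r * h], axis=-1), hidden_size, activation=tf.tanh,
      name="toy_candidate")
  return (1.0 - z) * x + z * candidate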
def universal_transformer_with_lstm_as_transition_function(
layer_inputs, step, hparams, ffn_unit, attention_unit, pad_remover=None):
"""Universal Transformer which uses a lstm as transition function.
It's kind of like having a lstm, filliped vertically next to the Universal
Transformer that controls the flow of the information in depth,
over different steps of the Universal Transformer.
Args:
layer_inputs:
- state: state
- inputs: the original embedded inputs (= inputs to the first step)
- memory: memory used in lstm.
step: indicates number of steps taken so far
hparams: model hyper-parameters.
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
pad_remover: to mask out padding in convolutional layers (efficiency).
Returns:
layer_output:
new_state: new state
inputs: the original embedded inputs (= inputs to the first step)
memory: contains information of state from all the previous steps.
"""
state, unused_inputs, memory = tf.unstack(
layer_inputs, num=None, axis=0, name="unstack")
# NOTE:
# state (ut_state): output of the lstm in the previous step
# inputs (ut_input): original input --> we don't use it here
# memory: lstm memory
# Multi_head_attention:
assert not hparams.add_step_timing_signal # Let lstm count for us!
mh_attention_input = step_preprocess(state, step, hparams)
transition_function_input = attention_unit(mh_attention_input)
# Transition Function:
if hparams.add_ffn_unit_to_the_transition_function:
transition_function_input = ffn_unit(transition_function_input)
transition_function_input = common_layers.layer_preprocess(
transition_function_input, hparams)
with tf.variable_scope("lstm"):
# lstm input gate: i_t = sigmoid(W_i.x_t + U_i.h_{t-1})
transition_function_input_gate = _ffn_layer_multi_inputs(
[transition_function_input, state],
hparams,
name="input",
bias_initializer=tf.zeros_initializer(),
activation=tf.sigmoid,
pad_remover=pad_remover)
tf.contrib.summary.scalar("lstm_input_gate",
tf.reduce_mean(transition_function_input_gate))
# lstm forget gate: f_t = sigmoid(W_f.x_t + U_f.h_{t-1})
transition_function_forget_gate = _ffn_layer_multi_inputs(
[transition_function_input, state],
hparams,
name="forget",
bias_initializer=tf.zeros_initializer(),
activation=None,
pad_remover=pad_remover)
forget_bias_tensor = tf.constant(hparams.lstm_forget_bias)
transition_function_forget_gate = tf.sigmoid(
transition_function_forget_gate + forget_bias_tensor)
tf.contrib.summary.scalar("lstm_forget_gate",
tf.reduce_mean(transition_function_forget_gate))
# lstm output gate: o_t = sigmoid(W_o.x_t + U_o.h_{t-1})
transition_function_output_gate = _ffn_layer_multi_inputs(
[transition_function_input, state],
hparams,
name="output",
bias_initializer=tf.zeros_initializer(),
activation=tf.sigmoid,
pad_remover=pad_remover)
tf.contrib.summary.scalar("lstm_output_gate",
tf.reduce_mean(transition_function_output_gate))
# lstm input modulation
transition_function_input_modulation = _ffn_layer_multi_inputs(
[transition_function_input, state],
hparams,
name="input_modulation",
bias_initializer=tf.zeros_initializer(),
activation=tf.tanh,
pad_remover=pad_remover)
transition_function_memory = (
memory * transition_function_forget_gate +
transition_function_input_gate * transition_function_input_modulation)
transition_function_output = (
tf.tanh(transition_function_memory) * transition_function_output_gate)
transition_function_output = common_layers.layer_preprocess(
transition_function_output, hparams)
return transition_function_output, unused_inputs, transition_function_memory
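# NOTE: similarly, the LSTM transition above follows the standard cell-state
# update once the helper layers are stripped away. A minimal sketch with assumed
# names (`_toy_lstm_transition`), folding the forget bias into the sigmoid as in
# the code above, follows.
def _toy_lstm_transition(x, h, c, hidden_size, forget_bias=1.0):
  """Returns (new_h, new_c) for one LSTM-style depth transition (illustrative)."""
  xh = tf.concat([x, h], axis=-1)
  i = tf.layers.dense(xh, hidden_size, activation=tf.sigmoid, name="toy_input")
  f = tf.sigmoid(
      tf.layers.dense(xh, hidden_size, name="toy_forget") + forget_bias)
  o = tf.layers.dense(xh, hidden_size, activation=tf.sigmoid, name="toy_output")
  j = tf.layers.dense(xh, hidden_size, activation=tf.tanh, name="toy_modulation")
  new_c = c * f + i * j
  new_h = tf.tanh(new_c) * o
  return new_h, new_c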
def universal_transformer_act(x, hparams, ffn_unit, attention_unit):
"""ACT based models.
Implementations of all act models are based on craffel@'s cl/160711592.
(1) Basic AUT based on remainder-distribution ACT (position-wise).
(2) AUT with global halting probability (not position-wise).
(3) AUT with random halting probability (not position-wise).
(4) AUT with final state as accumulation of all states.
Args:
x: input
hparams: model hyper-parameters
ffn_unit: feed-forward unit
attention_unit: multi-head attention unit
Returns:
the output tensor, (ponder_times, remainders)
Raises:
ValueError: Unknown act type
"""
if hparams.act_type not in ["basic", "global", "random", "accumulated"]:
raise ValueError("Unknown act type: %s" % hparams.act_type)
state = x
act_max_steps = hparams.act_max_steps
threshold = 1.0 - hparams.act_epsilon
state_shape_static = state.get_shape()
state_slice = slice(0, 2)
if hparams.act_type == "global":
state_slice = slice(0, 1)
# Dynamic shape for update tensors below
update_shape = tf.shape(state)[state_slice]
# Halting probabilities (p_t^n in the paper)
halting_probability = tf.zeros(update_shape, name="halting_probability")
# Remainders (R(t) in the paper)
remainders = tf.zeros(update_shape, name="remainder")
# Number of updates performed (N(t) in the paper)
n_updates = tf.zeros(update_shape, name="n_updates")
# Previous cell states (s_t in the paper)
previous_state = tf.zeros_like(state, name="previous_state")
step = tf.constant(0, dtype=tf.int32)
def ut_function(state, step, halting_probability, remainders, n_updates,
previous_state):
"""implements act (position-wise halting).
Args:
state: 3-D Tensor: [batch_size, length, channel]
step: indicates number of steps taken so far
halting_probability: halting probability
remainders: act remainders
n_updates: act n_updates
previous_state: previous state
Returns:
transformed_state: transformed state
step: step+1
halting_probability: halting probability
remainders: act remainders
n_updates: act n_updates
new_state: new state
"""
state = step_preprocess(state, step, hparams)
if hparams.act_type == "random":
# random as halting probability
p = tf.random_uniform(
shape=common_layers.shape_list(halting_probability))
else:
with tf.variable_scope("sigmoid_activation_for_pondering"):
p = common_layers.dense(
state,
1,
activation=tf.nn.sigmoid,
use_bias=True,
bias_initializer=tf.constant_initializer(
hparams.act_halting_bias_init))
if hparams.act_type == "global":
# average over all positions (as a global halting prob)
p = tf.reduce_mean(p, axis=1)
p = tf.squeeze(p)
else:
# maintain position-wise probabilities
p = tf.squeeze(p, axis=-1)
# Mask for inputs which have not halted yet
still_running = tf.cast(tf.less(halting_probability, 1.0), tf.float32)
# Mask of inputs which halted at this step
new_halted = tf.cast(
tf.greater(halting_probability + p * still_running, threshold),
tf.float32) * still_running
# Mask of inputs which haven't halted, and didn't halt this step
still_running = tf.cast(
tf.less_equal(halting_probability + p * still_running, threshold),
tf.float32) * still_running
# Add the halting probability for this step to the halting
# probabilities for those inputs which haven't halted yet
halting_probability += p * still_running
# Compute remainders for the inputs which halted at this step
remainders += new_halted * (1 - halting_probability)
# Add the remainders to those inputs which halted at this step
halting_probability += new_halted * remainders
# Increment n_updates for all inputs which are still running
n_updates += still_running + new_halted
# Compute the weight to be applied to the new state and output
# 0 when the input has already halted
# p when the input hasn't halted yet
# the remainders when it halted this step
update_weights = tf.expand_dims(
p * still_running + new_halted * remainders, -1)
if hparams.act_type == "global":
update_weights = tf.expand_dims(update_weights, -1)
# apply transformation on the state
transformed_state = state
for i in range(hparams.num_inrecurrence_layers):
with tf.variable_scope("rec_layer_%d" % i):
transformed_state = ffn_unit(attention_unit(transformed_state))
# update running part in the weighted state and keep the rest
new_state = ((transformed_state * update_weights) +
(previous_state * (1 - update_weights)))
if hparams.act_type == "accumulated":
# Add in the weighted state
new_state = (transformed_state * update_weights) + previous_state
# remind TensorFlow of everything's shape
transformed_state.set_shape(state_shape_static)
for x in [halting_probability, remainders, n_updates]:
x.set_shape(state_shape_static[state_slice])
new_state.set_shape(state_shape_static)
step += 1
return (transformed_state, step, halting_probability, remainders, n_updates,
new_state)
# While loop stops when this predicate is FALSE.
# I.e. it stops when all (probability < 1 - eps AND counter < N) are false.
def should_continue(u0, u1, halting_probability, u2, n_updates, u3):
del u0, u1, u2, u3
return tf.reduce_any(
tf.logical_and(
tf.less(halting_probability, threshold),
tf.less(n_updates, act_max_steps)))
# Do while loop iterations until predicate above is false.
(_, _, _, remainder, n_updates, new_state) = tf.while_loop(
should_continue, ut_function,
(state, step, halting_probability, remainders, n_updates, previous_state),
maximum_iterations=act_max_steps + 1)
ponder_times = n_updates
remainders = remainder
tf.contrib.summary.scalar("ponder_times", tf.reduce_mean(ponder_times))
return new_state, (ponder_times, remainders)
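# NOTE: the halting bookkeeping inside ut_function can be checked in isolation
# with plain Python. The hypothetical toy `_toy_act_halting` below replays the
# same update rules for a single position given a fixed sequence of halting
# probabilities; the returned per-step weights sum to 1.0 for a position that
# halts. It is an illustrative sketch, not part of the library.
def _toy_act_halting(p_values, threshold):
  """Returns (n_updates, per-step update weights) for one position (sketch)."""
  halting_probability, n_updates, weights = 0.0, 0, []
  for p in p_values:
    if halting_probability >= threshold:
      weights.append(0.0)  # already halted: this step contributes nothing
    elif halting_probability + p > threshold:
      remainder = 1.0 - halting_probability  # halts now: use the remainder
      halting_probability += remainder
      weights.append(remainder)
      n_updates += 1
    else:
      halting_probability += p  # still running: use p as the weight
      weights.append(p)
      n_updates += 1
  return n_updates, weights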
def _ffn_layer_multi_inputs(inputs_list,
hparams,
output_size=None,
ffn_layer_type="dense",
name="ffn",
kernel_initializer=None,
bias_initializer=None,
activation=None,
pad_remover=None,
preprocess=False,
postprocess=False):
"""Implements a Feed-forward layer with multiple inputs, pad-removing, etc.
Args:
inputs_list: list of input tensors
hparams: hyper-parameters
output_size: dimensionality of the output
ffn_layer_type: dense / dense_dropconnect/ dense_relu_dense
name: name
kernel_initializer: kernel initializer
bias_initializer: bias initializer
activation: activation function
pad_remover: pad remover
preprocess: whether to preprocess the input (default: layer norm)
postprocess: whether to postprocess the output (default: dropout and residual)
Returns:
a tensor
Raises:
ValueError: Unknown ffn_layer type.
"""
# need at least one input
num_inputs = len(inputs_list)
assert num_inputs > 0
if preprocess:
# In case of having more than one input to the ffn,
# we just apply layer norm on them independently as preprocessing
for i, inputs in enumerate(inputs_list):
inputs_list[i] = common_layers.layer_preprocess(inputs_list[i], hparams)
# for the residual connection
if postprocess and num_inputs == 1:
original_inputs = inputs_list[0]
# the output size is the hidden size of the main inputs
main_input = inputs_list[0]
original_shape = common_layers.shape_list(main_input)
assert hparams.hidden_size == common_layers.shape_list(main_input)[-1]
# all the inputs have the same shape as the main input
for inputs in inputs_list:
main_input.get_shape().assert_is_compatible_with(inputs.get_shape())
def remove_pads(x):
original_shape = common_layers.shape_list(x)
# Collapse `x` across examples, and remove padding positions.
x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0))
x = tf.expand_dims(pad_remover.remove(x), axis=0)
return x
if pad_remover:
for i, inputs in enumerate(inputs_list):
inputs_list[i] = remove_pads(inputs)
ffn_inputs = inputs_list[0]
if len(inputs_list) != 1:
ffn_inputs = tf.concat(inputs_list, axis=-1)
if ffn_layer_type == "dense":
output = common_layers.dense(
ffn_inputs,
hparams.hidden_size if output_size is None else output_size,
name=name,
activation=activation,
use_bias=True,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer)
elif ffn_layer_type == "dense_dropconnect":
output = common_layers.dense_dropconnect(
ffn_inputs,
hparams.hidden_size if output_size is None else output_size,
name=name,
dropconnect_dropout=hparams.dropconnect_dropout,
output_activation=activation)
postprocess = False # no dropout on the output unit
elif ffn_layer_type == "dense_relu_dense":
output = common_layers.dense_relu_dense(
ffn_inputs,
hparams.filter_size,
hparams.hidden_size if output_size is None else output_size,
name=name,
dropout=hparams.relu_dropout,
output_activation=activation,
)
else:
raise ValueError("Unknown ffn_layer type: %s" % ffn_layer_type)
if pad_remover:
# Restore `output` to the original shape of `x`, including padding.
output = tf.reshape(
pad_remover.restore(tf.squeeze(output, axis=0)), original_shape)
if postprocess:
if num_inputs == 1:
output = common_layers.layer_postprocess(original_inputs, output, hparams)
else: # only dropout (no residual)
hp = copy.copy(hparams)
hp.layer_postprocess_sequence = hp.layer_postprocess_sequence.replace(
"a", "")
output = common_layers.layer_postprocess(None, output, hp)
return output
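# NOTE: with pad removal and pre/post-processing left out, the core of the
# layer above is just a dense projection over the channel-wise concatenation of
# its inputs. A stripped-down sketch for two inputs follows; the name
# `_toy_ffn_two_inputs` is an illustrative assumption.
def _toy_ffn_two_inputs(a, b, output_size, activation=tf.sigmoid):
  """Dense layer over the channel-wise concat of two inputs (illustrative)."""
  return tf.layers.dense(
      tf.concat([a, b], axis=-1), output_size, activation=activation,
      name="toy_ffn_two_inputs")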
def fill_memory_slot(memory, value, index):
"""Fills the memory slot at a particular index with the given value.
Args:
memory: a 4-d tensor [memory_size, batch, length, channel] containing
the state of all steps
value: a 3-d tensor [batch, length, channel] as the state
index: integer in [0, memory_size)
Returns:
filled memory
"""
mask = tf.to_float(
tf.one_hot(index,
tf.shape(memory)[0])[:, None, None, None])
fill_memory = (1 - mask) * memory + mask * value[None, ...]
return fill_memory
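# NOTE: a tiny, hypothetical usage example of fill_memory_slot: with a memory
# of shape [4, 2, 3, 8] and index=1, the one-hot mask has shape [4, 1, 1, 1],
# so slots 0, 2 and 3 are kept unchanged and slot 1 is overwritten by `value`.
def _demo_fill_memory_slot():
  """Overwrites exactly one memory slot (illustrative sketch)."""
  memory = tf.zeros([4, 2, 3, 8])  # [memory_size, batch, length, channel]
  value = tf.ones([2, 3, 8])       # the new state for one step
  return fill_memory_slot(memory, value, 1)  # only slot 1 becomes ones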
def add_depth_embedding(x):
"""Add n-dimensional embedding as the depth embedding (timing signal).
Adds embeddings to represent the position of the step in the recurrent
tower.
Args:
x: a tensor with shape [max_step, batch, length, depth]
Returns:
a Tensor the same shape as x.
"""
x_shape = common_layers.shape_list(x)
depth = x_shape[-1]
num_steps = x_shape[0]
shape = [num_steps, 1, 1, depth]
depth_embedding = (
tf.get_variable(
"depth_embedding",
shape,
initializer=tf.random_normal_initializer(0, depth**-0.5)) * (depth**
0.5))
x += depth_embedding
return x
def step_preprocess(x, step, hparams):
"""Preprocess the input at the beginning of each step.
Args:
x: input tensor
step: step
hparams: model hyper-parameters
Returns:
preprocessed input.
"""
original_channel_size = common_layers.shape_list(x)[-1]
if hparams.add_position_timing_signal:
x = add_position_timing_signal(x, step, hparams)
if hparams.add_step_timing_signal:
x = add_step_timing_signal(x, step, hparams)
if ((hparams.add_position_timing_signal or hparams.add_step_timing_signal)
and hparams.add_or_concat_timing_signal == "concat"):
# linear projection to the original dimension of x
x = common_layers.dense(
x, original_channel_size, activation=None, use_bias=False)
if hparams.add_sru:
x = common_layers.sru(x)
return x
def add_position_timing_signal(x, step, hparams):
"""Add n-dimensional embedding as the position (horizontal) timing signal.
Args:
x: a tensor with shape [batch, length, depth]
step: step
hparams: model hyper parameters
Returns:
a Tensor with the same shape as x.
"""
if not hparams.position_start_index:
index = 0
elif hparams.position_start_index == "random":
# Shift all positions randomly
# TODO(dehghani): What would be a reasonable max number of shifts?
index = tf.random_uniform(
[], maxval=common_layers.shape_list(x)[1], dtype=tf.int32)
elif hparams.position_start_index == "step":
# Shift positions based on the step
if hparams.recurrence_type == "act":
num_steps = hparams.act_max_steps
else:
num_steps = hparams.num_rec_steps
index = tf.cast(
common_layers.shape_list(x)[1] * step / num_steps, dtype=tf.int32)
# No need for the timing signal in the encoder/decoder input preparation
assert hparams.pos is None
length = common_layers.shape_list(x)[1]
channels = common_layers.shape_list(x)[2]
signal = common_attention.get_timing_signal_1d(
length, channels, start_index=index)
if hparams.add_or_concat_timing_signal == "add":
x_with_timing = x + common_layers.cast_like(signal, x)
elif hparams.add_or_concat_timing_signal == "concat":
batch_size = common_layers.shape_list(x)[0]
signal_tiled = tf.tile(signal, [batch_size, 1, 1])
x_with_timing = tf.concat((x, signal_tiled), axis=-1)
return x_with_timing
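# NOTE: in "concat" mode the timing signal grows the channel dimension, which
# is why step_preprocess projects back to the original width. A round-trip
# sketch under assumed shapes follows; the name `_toy_concat_timing_and_project`
# is an illustrative assumption.
def _toy_concat_timing_and_project(x, original_channel_size):
  """Concats a position timing signal, then projects back (illustrative)."""
  batch_size = common_layers.shape_list(x)[0]
  length = common_layers.shape_list(x)[1]
  channels = common_layers.shape_list(x)[2]
  signal = common_attention.get_timing_signal_1d(length, channels)
  x = tf.concat((x, tf.tile(signal, [batch_size, 1, 1])), axis=-1)
  # linear projection back to the original dimension of x
  return common_layers.dense(
      x, original_channel_size, activation=None, use_bias=False)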
def add_step_timing_signal(x, step, hparams):
"""Add n-dimensional embedding as the step (vertical) timing signal.
Args:
x: a tensor with shape [batch, length, depth]
step: step
hparams: model hyper parameters
Returns:
a Tensor with the same shape as x.
"""
if hparams.recurrence_type == "act":
num_steps = hparams.act_max_steps
else:
num_steps = hparams.num_rec_steps
channels = common_layers.shape_list(x)[-1]
if hparams.step_timing_signal_type == "learned":
signal = common_attention.get_layer_timing_signal_learned_1d(
channels, step, num_steps)
elif hparams.step_timing_signal_type == "sinusoid":
signal = common_attention.get_layer_timing_signal_sinusoid_1d(
channels, step, num_steps)
if hparams.add_or_concat_timing_signal == "add":
x_with_timing = x + common_layers.cast_like(signal, x)
elif hparams.add_or_concat_timing_signal == "concat":
batch_size = common_layers.shape_list(x)[0]
length = common_layers.shape_list(x)[1]
signal_tiled = tf.tile(signal, [batch_size, length, 1])
x_with_timing = tf.concat((x, signal_tiled), axis=-1)
return x_with_timing