Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
0 contributors

Users who have contributed to this file

434 lines (407 sloc) 18.8 KB
# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for area attention."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.layers import common_layers
import tensorflow as tf
def lengths_to_area_mask(feature_length, length, max_area_size):
  """Builds a boolean mask marking the areas that contain no padding.

  Args:
    feature_length: a tensor of [batch_size], passed to tf.sequence_mask as the
      per-example lengths.
    length: the length of the batch (used as maxlen of the sequence mask).
    max_area_size: the maximum area size considered.

  Returns:
    mask: a boolean tensor in shape of [batch_size, num_areas]. An entry is
      True when the padding indicator summed over that area is zero.
  """
  # Padding indicator of shape [batch_size, length, 1]: 1.0 on padded
  # positions, 0.0 elsewhere.
  is_valid = tf.sequence_mask(feature_length, maxlen=length)
  is_padding = tf.logical_not(is_valid)
  padding_indicator = tf.cast(tf.expand_dims(is_padding, 2), tf.float32)
  # Summing the indicator per area counts the padded positions in each area.
  area_padding_count = compute_area_features(
      padding_indicator, max_area_width=max_area_size)[2]
  has_padding = tf.cast(area_padding_count, tf.bool)
  return tf.squeeze(tf.logical_not(has_padding), [2])
def _pool_one_shape(features_2d, area_width, area_height, batch_size,
                    width, height, depth, fn=tf.reduce_max, name=None):
  """Pools every area of one fixed shape in features_2d.

  For each (y, x) offset inside an area_height x area_width window, a shifted
  crop of features_2d is taken and flattened over its spatial dimensions. The
  crops are stacked along a new trailing axis and reduced with fn, so each
  output position aggregates all cells of the window anchored at it.

  Args:
    features_2d: a Tensor in a shape of [batch_size, height, width, depth].
    area_width: the width of the areas to pool.
    area_height: the height of the areas to pool.
    batch_size: the batch size.
    width: the width of the memory.
    height: the height of the memory.
    depth: the depth of the features.
    fn: the TF reduction function used for pooling; called with axis=3.
    name: the op name.

  Returns:
    pool_tensor: A Tensor of shape [batch_size, num_areas, depth]
  """
  with tf.name_scope(name, default_name="pool_one_shape"):
    shifted_crops = []
    for dy in range(area_height):
      # Exclusive row bound of the crop; clamped at 0 when the area is taller
      # than the memory.
      row_end = tf.maximum(height - area_height + 1 + dy, 0)
      for dx in range(area_width):
        # Exclusive column bound of the crop, clamped the same way.
        col_end = tf.maximum(width - area_width + 1 + dx, 0)
        crop = features_2d[:, dy:row_end, dx:col_end, :]
        shifted_crops.append(tf.reshape(crop, [batch_size, -1, depth, 1]))
    stacked = tf.concat(shifted_crops, axis=3)
    pooled = fn(stacked, axis=3)
  return pooled
def basic_pool(features, max_area_width, max_area_height=1, height=1,
               fn=tf.reduce_max, name=None):
  """Pools every area with the given pooling function (fn).

  Areas are enumerated shape by shape: the outer loop runs over area heights
  1..max_area_height and the inner loop over area widths 1..max_area_width.
  The results of all shapes are concatenated along the area axis.

  Args:
    features: a Tensor in a shape of [batch_size, height * width, depth].
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    height: the height of the image.
    fn: the TF function for the pooling.
    name: the namescope.

  Returns:
    pool_results: A Tensor of shape [batch_size, num_areas, depth]
    area_heights: A Tensor of shape [batch_size, num_areas, 1]
    area_widths: A Tensor of shape [batch_size, num_areas, 1]
  """
  with tf.name_scope(name, default_name="basic_pool"):
    dims = common_layers.shape_list(features)
    batch_size, length, depth = dims[0], dims[-2], dims[-1]
    width = length // height
    features_2d = tf.reshape(features, [batch_size, height, width, depth])
    # Integer ones over the 2D grid; slices of it scaled by the area size give
    # the per-area height / width tensors.
    ones_grid = tf.ones_like(features_2d[:, :, :, 0], dtype=tf.int32)
    pooled_per_shape = []
    heights_per_shape = []
    widths_per_shape = []
    for h_index in range(max_area_height):
      for w_index in range(max_area_width):
        pooled = _pool_one_shape(features_2d,
                                 area_width=w_index + 1,
                                 area_height=h_index + 1,
                                 batch_size=batch_size,
                                 width=width,
                                 height=height,
                                 depth=depth,
                                 fn=fn)
        pooled_per_shape.append(tf.reshape(pooled, [batch_size, -1, depth]))
        shape_heights = ones_grid[:, h_index:, w_index:] * (h_index + 1)
        heights_per_shape.append(tf.reshape(shape_heights, [batch_size, -1]))
        shape_widths = ones_grid[:, h_index:, w_index:] * (w_index + 1)
        widths_per_shape.append(tf.reshape(shape_widths, [batch_size, -1]))
    pool_results = tf.concat(pooled_per_shape, axis=1)
    area_heights = tf.expand_dims(tf.concat(heights_per_shape, axis=1), 2)
    area_widths = tf.expand_dims(tf.concat(widths_per_shape, axis=1), 2)
  return pool_results, area_heights, area_widths
def _compute_sum_image(features, max_area_width, max_area_height=1, height=1,
                       name=None):
  """Computes the sum of the features over every area.

  The features are reshaped to a 2D grid and turned into an integral image
  (cumulative sums along width, then height) with a zero row and column
  padded at the top / left. The sum over an area is then obtained from four
  corner lookups: bottom-right + top-left - top-right - bottom-left.

  Args:
    features: a Tensor in a shape of [batch_size, height * width, depth].
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    height: the height of the image.
    name: the namescope.

  Returns:
    sum_image: A Tensor of shape [batch_size, num_areas, depth]
    area_heights: A Tensor of shape [batch_size, num_areas, 1]
    area_widths: A Tensor of shape [batch_size, num_areas, 1]
  """
  with tf.name_scope(name, default_name="compute_sum_image"):
    dims = common_layers.shape_list(features)
    batch_size, length, depth = dims[0], dims[-2], dims[-1]
    width = length // height
    features_2d = tf.reshape(features, [batch_size, height, width, depth])
    row_cumsum = tf.cumsum(features_2d, axis=-2, name="compute_integral_h")
    integral = tf.cumsum(row_cumsum, axis=-3, name="compute_integral_v")
    # Leading zero row / column so areas touching the top or left border can
    # use the same four-corner formula.
    padded = tf.pad(
        integral, [[0, 0], [1, 0], [1, 0], [0, 0]], constant_values=0)
    ones_grid = tf.ones_like(padded[:, :, :, 0], dtype=tf.int32)

    bottom_right = []
    top_left = []
    bottom_left = []
    top_right = []
    heights_per_shape = []
    widths_per_shape = []
    flat_shape = [batch_size, -1, depth]
    for h_index in range(max_area_height):
      area_h = h_index + 1
      for w_index in range(max_area_width):
        area_w = w_index + 1
        bottom_right.append(
            tf.reshape(padded[:, area_h:, area_w:, :], flat_shape))
        top_left.append(
            tf.reshape(padded[:, :-area_h, :-area_w, :], flat_shape))
        bottom_left.append(
            tf.reshape(padded[:, area_h:, :-area_w, :], flat_shape))
        top_right.append(
            tf.reshape(padded[:, :-area_h, area_w:, :], flat_shape))
        shape_heights = ones_grid[:, area_h:, area_w:] * area_h
        heights_per_shape.append(tf.reshape(shape_heights, [batch_size, -1]))
        shape_widths = ones_grid[:, area_h:, area_w:] * area_w
        widths_per_shape.append(tf.reshape(shape_widths, [batch_size, -1]))

    added = tf.concat(bottom_right, axis=1) + tf.concat(top_left, axis=1)
    removed = tf.concat(top_right, axis=1) + tf.concat(bottom_left, axis=1)
    sum_image = tf.subtract(added, removed)
    area_heights = tf.expand_dims(tf.concat(heights_per_shape, axis=1), 2)
    area_widths = tf.expand_dims(tf.concat(widths_per_shape, axis=1), 2)
  return sum_image, area_heights, area_widths
def compute_area_features(features, max_area_width, max_area_height=1, height=1,
                          epsilon=1e-6):
  """Computes the mean, standard deviation, sum and size of each area.

  The standard deviation is derived from the area sums of the features and of
  the squared features: sqrt(|E[x^2] - E[x]^2| + epsilon).

  Args:
    features: a Tensor in a shape of [batch_size, height * width, depth].
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    height: the height of the image.
    epsilon: the epsilon added to the variance for computing standard deviation.

  Returns:
    area_mean: A Tensor of shape [batch_size, num_areas, depth]
    area_std: A Tensor of shape [batch_size, num_areas, depth]
    area_sum: A Tensor of shape [batch_size, num_areas, depth]
    area_heights: A Tensor of shape [batch_size, num_areas, 1]
    area_widths: A Tensor of shape [batch_size, num_areas, 1]
  """
  with tf.name_scope("compute_area_features"):
    tf.logging.info("area_attention compute_area_features: %d x %d",
                    max_area_height, max_area_width)
    area_sum, area_heights, area_widths = _compute_sum_image(
        features,
        max_area_width=max_area_width,
        max_area_height=max_area_height,
        height=height)
    squared_features = tf.pow(features, 2)
    area_squared_sum = _compute_sum_image(
        squared_features,
        max_area_width=max_area_width,
        max_area_height=max_area_height,
        height=height)[0]
    # Number of cells in each area, as float for the divisions below.
    cell_counts = tf.to_float(tf.multiply(area_heights, area_widths))
    area_mean = tf.div(area_sum, cell_counts)
    mean_of_squares = tf.div(area_squared_sum, cell_counts)
    area_variance = tf.subtract(mean_of_squares, tf.pow(area_mean, 2))
    # abs() guards against slightly negative variances before the sqrt.
    area_std = tf.sqrt(tf.abs(area_variance) + epsilon)
    return area_mean, area_std, area_sum, area_heights, area_widths
def compute_area_key(features, max_area_width, max_area_height=1, height=1,
                     mode="mean", training=True, name=None):
  """Computes the key for each area.

  Args:
    features: a Tensor in a shape of [batch_size, height * width, depth].
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    height: the height of the image.
    mode: how the key of an area is derived from its features. One of:
      "mean": the mean of the area.
      "max": the max-pooled features of the area.
      "sample": the mean, perturbed by std-scaled Gaussian noise when training.
      "concat": mean, std and size embedding concatenated.
      "max_concat": max-pooled features and size embedding concatenated.
      "sum": mean, std and size embedding added together.
      "sample_concat": sampled mean (see "sample") concatenated with the size
        embedding.
      "sample_sum": sampled mean (see "sample") added to the size embedding.
      The last five modes create variables (size embeddings and two dense
      layers) under the variable scope given by `name`.
    training: indicating if it is in the training mode.
    name: the name for setting the variable scope.

  Returns:
    area_key: a Tensor in the shape of [batch_size, num_areas, depth]

  Raises:
    ValueError: if `mode` is not one of the supported modes.
  """
  tf.logging.info("area_attention mode=%s", mode)
  area_mean, area_std, _, area_heights, area_widths = compute_area_features(
      features, max_area_width=max_area_width,
      max_area_height=max_area_height, height=height)
  if mode == "mean":
    return area_mean
  elif mode == "max":
    area_max, _, _ = basic_pool(features, max_area_width=max_area_width,
                                max_area_height=max_area_height, height=height)
    return area_max
  elif mode == "sample":
    if training:
      area_mean += (area_std * tf.random_normal(tf.shape(area_std)))
    return area_mean

  # Reject unknown modes before any variable is created, so a bad mode does
  # not leave unused embedding variables behind in the graph.
  combined_modes = ("concat", "max_concat", "sum", "sample_concat",
                    "sample_sum")
  if mode not in combined_modes:
    raise ValueError("Unsupported area key mode=%s" % mode)

  with tf.variable_scope(
      name, default_name="combine_area_features",
      values=[area_mean, area_std, area_heights, area_widths]):
    depth = common_layers.shape_list(area_mean)[-1]
    # Heights / widths are 1-based, hence the "- 1" to get embedding ids.
    height_embed = tf.nn.embedding_lookup(
        params=tf.get_variable("area_height_emb",
                               [max_area_height, depth // 2]),
        ids=area_heights[:, :, 0] - 1)
    width_embed = tf.nn.embedding_lookup(
        params=tf.get_variable("area_width_emb",
                               [max_area_width, depth // 2]),
        ids=area_widths[:, :, 0] - 1)
    # The size embedding has 2 * (depth // 2) channels.
    # NOTE(review): the additive modes ("sum", "sample_sum") presumably need
    # an even depth for this to line up with area_mean -- confirm.
    size_embed = tf.concat([height_embed, width_embed], -1)
    if mode == "concat":
      feature_concat = tf.concat([area_mean, area_std, size_embed], -1)
    elif mode == "max_concat":
      area_max, _, _ = basic_pool(features, max_area_width=max_area_width,
                                  max_area_height=max_area_height,
                                  height=height)
      feature_concat = tf.concat([area_max, size_embed], -1)
    elif mode == "sum":
      feature_concat = size_embed + area_mean + area_std
    elif mode == "sample_concat":
      if training:
        area_mean += (area_std * tf.random_normal(tf.shape(area_std)))
      feature_concat = tf.concat([area_mean, size_embed], -1)
    else:  # mode == "sample_sum"; all other values were rejected above.
      if training:
        area_mean += (area_std * tf.random_normal(tf.shape(area_std)))
      feature_concat = area_mean + size_embed
    # Two-layer projection of the combined features back to `depth` channels.
    feature_hidden = tf.layers.dense(inputs=feature_concat,
                                     units=depth,
                                     activation=tf.nn.relu)
    area_key = tf.layers.dense(feature_hidden, units=depth)
    return area_key
def dot_product_area_attention(q,
                               k,
                               v,
                               bias,
                               dropout_rate=0.0,
                               image_shapes=None,
                               name=None,
                               attention_image_summary=None,
                               save_weights_to=None,
                               dropout_broadcast_dims=None,
                               max_area_width=1,
                               max_area_height=1,
                               memory_height=1,
                               area_key_mode="mean",
                               area_value_mode="sum",
                               top_k_areas=0,
                               area_temperature=1.0,
                               training=True):
  """Dot-product area attention.

  Keys and values are first aggregated over every area of the memory (up to
  max_area_height x max_area_width), then regular dot-product attention is
  run over the resulting area keys / values.

  Args:
    q: Tensor with shape [batch, heads, length_q, depth_k].
    k: Tensor with shape [batch, heads, length_kv, depth_k]. The first four
      dimensions are read as batch, heads, length and depth.
    v: Tensor with shape [batch, heads, length_kv, depth_v]. Leading dimensions
      must match with k; depth_v may differ from depth_k.
    bias: bias Tensor (see attention_bias()), or None. Entries less than -1
      are treated as padded memory positions; an area containing any such
      position receives a bias of -inf.
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    attention_image_summary: the callback for making image summary of attention.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be stored there under
      a string key created from the variable scope (including name), and the
      logits under that key suffixed with "/logits".
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    memory_height: the height of the memory.
    area_key_mode: the mode for computing area keys, which can be "mean",
      "max", "sample", "concat", "max_concat", "sum", "sample_concat", and
      "sample_sum" (see compute_area_key()).
    area_value_mode: the mode for computing area values, which can be
      "mean", "max", or "sum".
    top_k_areas: if positive, only the top_k_areas highest-weighted areas are
      kept per query and their weights renormalized.
    area_temperature: the temperature for attention softmax.
    training: indicating if it is in the training mode.

  Returns:
    Tensor with shape [batch, heads, length_q, depth_v].

  Raises:
    ValueError: if area_value_mode is not supported.
  """
  tf.logging.info("dot_product_area_attention: "
                  "area_h=%d, area_w=%d, mem_h=%d, "
                  "area_key_mode=%s, area_value_mode=%s, "
                  "area_temperature=%f",
                  max_area_height, max_area_width, memory_height,
                  area_key_mode, area_value_mode,
                  area_temperature)
  with tf.variable_scope(
      name, default_name="dot_product_area_attention",
      values=[q, k, v]) as scope:
    mem_shape = common_layers.shape_list(k)
    batch_size = mem_shape[0]
    head_size = mem_shape[1]
    length = mem_shape[2]
    depth = mem_shape[3]
    # Values carry their own depth; reshaping them with the key depth would
    # scramble (or fail on) values whose depth_v differs from depth_k.
    value_depth = common_layers.shape_list(v)[-1]
    k_area = compute_area_key(
        tf.reshape(k, [-1, length, depth]),
        max_area_width=max_area_width,
        max_area_height=max_area_height,
        height=memory_height,
        mode=area_key_mode,
        training=training)
    if area_value_mode == "mean":
      v_area, _, _, _, _ = compute_area_features(
          tf.reshape(v, [-1, length, value_depth]),
          max_area_width=max_area_width,
          max_area_height=max_area_height, height=memory_height)
    elif area_value_mode == "max":
      v_area, _, _ = basic_pool(tf.reshape(v, [-1, length, value_depth]),
                                max_area_width=max_area_width,
                                max_area_height=max_area_height,
                                height=memory_height,
                                fn=tf.reduce_max)
    elif area_value_mode == "sum":
      _, _, v_area, _, _ = compute_area_features(
          tf.reshape(v, [-1, length, value_depth]),
          max_area_width=max_area_width,
          max_area_height=max_area_height, height=memory_height)
    else:
      raise ValueError("Unsupported area value mode=%s" % area_value_mode)
    # Restore the [batch, heads, num_areas, depth] layout.
    k = tf.reshape(k_area, [batch_size, head_size, -1, depth])
    v = tf.reshape(v_area, [batch_size, head_size, -1, value_depth])
    logits = tf.matmul(q, k, transpose_b=True)  # [..., length_q, num_areas]
    if bias is not None:
      bias = common_layers.cast_like(bias, logits)
      with tf.name_scope("compute_area_att_bias", values=[bias]):
        bias_shape = common_layers.shape_list(bias)
        mem_length = bias_shape[-1]
        # 1.0 where the position-level bias masks a memory position.
        bias_values = tf.reshape(
            tf.to_float(tf.less(bias, -1)), [-1, mem_length, 1])
        # Count the masked positions inside each area.
        _, _, padding_sum, _, _ = compute_area_features(
            bias_values, max_area_width=max_area_width,
            max_area_height=max_area_height, height=memory_height)
        # An area with at least one masked position is masked entirely.
        bias = tf.where(
            tf.cast(tf.to_int32(padding_sum), tf.bool),
            tf.fill(tf.shape(padding_sum), -np.inf),
            tf.zeros_like(padding_sum, dtype=tf.float32))
        bias = tf.reshape(bias,
                          [bias_shape[0], bias_shape[1],
                           bias_shape[2], -1])
      logits += bias
    logits = logits / area_temperature
    weights = tf.nn.softmax(logits, name="attention_weights")
    if top_k_areas > 0:
      tf.logging.info("area_attention top_k_areas=%d", top_k_areas)
      top_k = tf.minimum(common_layers.shape_list(weights)[-1], top_k_areas)
      top_weights, _ = tf.nn.top_k(weights, k=top_k)
      # Smallest weight that still belongs to the top-k set of each query.
      min_values = tf.reduce_min(top_weights, -1, keepdims=True)
      weights = tf.where(tf.greater_equal(weights, min_values),
                         weights, tf.zeros_like(weights))
      # Renormalize the surviving weights so they sum to one again.
      weights = tf.div(weights, tf.reduce_sum(weights, -1, keepdims=True))
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = logits
    # Drop out attention links for each head.
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and attention_image_summary:
      attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
You can’t perform that action at this time.