127 changes: 126 additions & 1 deletion tensor2tensor/models/common_layers.py
@@ -292,7 +292,7 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
"""Conditional conv_fn making kernel 1d or 2d depending on inputs shape."""
static_shape = inputs.get_shape()
if not static_shape or len(static_shape) != 4:
raise ValueError("Inputs to conv must have statically known rank 4.")
raise ValueError("Inputs to conv must have statically known rank 4. Shape:" +str(static_shape))
# Add support for left padding.
if "padding" in kwargs and kwargs["padding"] == "LEFT":
dilation_rate = (1, 1)
@@ -1378,3 +1378,128 @@ def smoothing_cross_entropy(logits, labels, vocab_size, confidence):
  xentropy = tf.nn.softmax_cross_entropy_with_logits(
      logits=logits, labels=soft_targets)
  return xentropy - normalizing


def global_pool_1d(inputs, pooling_type='MAX', mask=None):
  """Pools elements across the second (sequence) dimension.

  Useful for converting a list of vectors into a single vector that
  represents the whole set.

  Args:
    inputs: A tensor of dimensions batch_size x sequence_length x input_dims
      containing the sequences of input vectors.
    pooling_type: The pooling type to use, MAX or AVR.
    mask: A tensor of dimensions batch_size x sequence_length containing a
      mask for the inputs with 1's for existing elements, and 0's elsewhere.

  Returns:
    output: A tensor of dimensions batch_size x input_dims containing the
      pooled vectors.
  """
  with tf.name_scope("global_pool", values=[inputs]):
    if mask is not None:
      mask = tf.expand_dims(mask, axis=2)
      inputs = tf.multiply(inputs, mask)

    if pooling_type == 'MAX':
      # A tf.pool can be used here, but reduce is cleaner.
      output = tf.reduce_max(inputs, axis=1)
    elif pooling_type == 'AVR':
      if mask is not None:
        # Some elements are dummy elements, so we can't just take the mean.
        output = tf.reduce_sum(inputs, axis=1)
        num_elems = tf.reduce_sum(mask, axis=1, keep_dims=True)
        output = tf.div(output, num_elems)
        # N.B.: this produces NaNs if a batch entry contains no elements.
      else:
        output = tf.reduce_mean(inputs, axis=1)

  return output
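A minimal usage sketch (editorial illustration, not part of the diff; assumes numpy and tensorflow are imported as np and tf, as elsewhere in this file, with illustrative shapes):

# Pool a batch of 3 sequences of 4 vectors of size 8 down to one vector each;
# the last two positions of every sequence are treated as padding.
x = tf.constant(np.random.rand(3, 4, 8), dtype=tf.float32)
mask = tf.constant([[1.0, 1.0, 0.0, 0.0]] * 3)

max_pooled = global_pool_1d(x)                    # shape (3, 8)
avg_pooled = global_pool_1d(x, 'AVR', mask=mask)  # mean over the 2 real elements

with tf.Session() as sess:
  print(sess.run(avg_pooled).shape)  # (3, 8)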


def linear_set_layer(layer_size,
                     inputs,
                     context=None,
                     activation_fn=tf.nn.relu,
                     dropout=0.0,
                     name=None):
  """Basic layer type for doing funky things with sets.

  Applies a linear transformation to each element in the input set.
  If a context is supplied, it is concatenated with the inputs.
  E.g., one can use global_pool_1d to get a representation of the set
  which can then be used as the context for the next layer.

  Args:
    layer_size: Dimension to transform the input vectors to.
    inputs: A tensor of dimensions batch_size x sequence_length x input_dims
      containing the sequences of input vectors.
    context: A tensor of dimensions batch_size x context_dims containing a
      global statistic about the set.
    activation_fn: The activation function to use.
    dropout: Dropout probability.
    name: An optional string name for the variable scope.

  Returns:
    output: A tensor of dimensions batch_size x sequence_length x layer_size
      containing the sequences of transformed vectors.

  TODO: Add bias add.
  """
  with tf.variable_scope(name, "linear_set_layer", [inputs]):
    # Apply a 1D convolution with kernel size 1, i.e. the same linear
    # transformation to each element along the second dimension.
    outputs = conv1d(inputs, layer_size, 1, activation=None, name="set_conv")

    # Apply the context, if it exists.
    if context is not None:
      # Unfortunately tf doesn't support broadcasting via concat, but we can
      # simply add the transformed context to get the same effect.
      context = tf.expand_dims(context, axis=1)
      cont_tfm = conv1d(context, layer_size, 1,
                        activation=None, name="cont_conv")
      outputs += cont_tfm

    if activation_fn is not None:
      outputs = activation_fn(outputs)

    if dropout != 0.0:
      outputs = tf.nn.dropout(outputs, 1.0 - dropout)

  return outputs
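A sketch of the stacking pattern the docstring describes (editorial illustration, shapes assumed): pool the set into a context vector, then feed it to the next layer:

x = tf.constant(np.random.rand(5, 4, 11), dtype=tf.float32)    # batch x length x dims
h1 = linear_set_layer(32, x, name="layer1")                    # (5, 4, 32)
context = global_pool_1d(h1)                                   # (5, 32) set summary
h2 = linear_set_layer(32, h1, context=context, name="layer2")  # (5, 4, 32)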


def ravanbakhsh_set_layer(layer_size,
                          inputs,
                          mask=None,
                          activation_fn=tf.nn.tanh,
                          dropout=0.0,
                          name=None):
  """Layer from Deep Sets paper (Ravanbakhsh et al.): https://arxiv.org/abs/1611.04500

  More parameter-efficient version of a linear set layer with context.

  Args:
    layer_size: Dimension to transform the input vectors to.
    inputs: A tensor of dimensions batch_size x sequence_length x input_dims
      containing the sequences of input vectors.
    mask: A tensor of dimensions batch_size x sequence_length containing a
      mask for the inputs with 1's for existing elements, and 0's elsewhere.
    activation_fn: The activation function to use.
    dropout: Dropout probability.
    name: An optional string name for the variable scope.

  Returns:
    output: A tensor of dimensions batch_size x sequence_length x layer_size
      containing the sequences of transformed vectors.
  """
  with tf.variable_scope(name, "ravanbakhsh_set_layer", [inputs]):
    # Subtract the pooled set representation from each element, then apply a
    # shared linear transformation; dropout is passed through.
    output = linear_set_layer(
        layer_size,
        inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1),
        activation_fn=activation_fn,
        dropout=dropout,
        name=name)

  return output
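In other words, the layer computes activation_fn(W(x_i - max_j x_j)): each element has the max-pooled set representation subtracted before the shared linear map, which keeps the layer permutation-equivariant. A usage sketch (editorial illustration):

x = tf.constant(np.random.rand(5, 4, 11), dtype=tf.float32)
y = ravanbakhsh_set_layer(32, x)  # (5, 4, 32), tanh activation by default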


72 changes: 70 additions & 2 deletions tensor2tensor/models/common_layers_test.py
@@ -50,15 +50,15 @@ def testSaturatingSigmoid(self):
    self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0])

  def testFlatten4D3D(self):
-    x = np.random.random_integers(1, high=8, size=(3, 5, 2))
+    x = np.random.randint(1, 9, size=(3, 5, 2))
    with self.test_session() as session:
      y = common_layers.flatten4d3d(common_layers.embedding(x, 10, 7))
      session.run(tf.global_variables_initializer())
      res = session.run(y)
    self.assertEqual(res.shape, (3, 5 * 2, 7))

  def testEmbedding(self):
-    x = np.random.random_integers(1, high=8, size=(3, 5))
+    x = np.random.randint(1, 9, size=(3, 5))
    with self.test_session() as session:
      y = common_layers.embedding(x, 10, 16)
      session.run(tf.global_variables_initializer())
@@ -81,6 +81,14 @@ def testConv(self):
      session.run(tf.global_variables_initializer())
      res = session.run(y)
    self.assertEqual(res.shape, (5, 5, 1, 13))

  def testConv1d(self):
    x = np.random.rand(5, 7, 11)
    with self.test_session() as session:
      y = common_layers.conv1d(tf.constant(x, dtype=tf.float32), 13, 1)
      session.run(tf.global_variables_initializer())
      res = session.run(y)
    self.assertEqual(res.shape, (5, 7, 13))

  def testSeparableConv(self):
    x = np.random.rand(5, 7, 1, 11)
@@ -293,6 +301,66 @@ def testDeconvStride2MultiStep(self):
      session.run(tf.global_variables_initializer())
      actual = session.run(a)
    self.assertEqual(actual.shape, (5, 32, 1, 16))

  def testGlobalPool1d(self):
    x1 = np.random.rand(5, 4, 11)
    no_mask = np.ones((5, 4))
    full_mask = np.zeros((5, 4))

    with self.test_session() as session:
      x1_ = tf.Variable(x1, dtype=tf.float32)
      no_mask_ = tf.Variable(no_mask, dtype=tf.float32)
      full_mask_ = tf.Variable(full_mask, dtype=tf.float32)

      none_mask_max = common_layers.global_pool_1d(x1_)
      no_mask_max = common_layers.global_pool_1d(x1_, mask=no_mask_)
      result1 = tf.reduce_sum(none_mask_max - no_mask_max)

      full_mask_max = common_layers.global_pool_1d(x1_, mask=full_mask_)
      result2 = tf.reduce_sum(full_mask_max)

      none_mask_avr = common_layers.global_pool_1d(x1_, 'AVR')
      no_mask_avr = common_layers.global_pool_1d(x1_, 'AVR', no_mask_)
      result3 = tf.reduce_sum(none_mask_avr - no_mask_avr)

      full_mask_avr = common_layers.global_pool_1d(x1_, 'AVR', full_mask_)
      result4 = tf.reduce_sum(full_mask_avr)

      session.run(tf.global_variables_initializer())
      actual = session.run([result1, result2, result3, result4])
      # N.B.: the last result (all elements masked out) is NaN by design,
      # so only the first three are checked.
      self.assertAllEqual(actual[:3], [0.0, 0.0, 0.0])


  def testLinearSetLayer(self):
    x1 = np.random.rand(5, 4, 11)
    cont = np.random.rand(5, 13)
    with self.test_session() as session:
      x1_ = tf.Variable(x1, dtype=tf.float32)
      cont_ = tf.Variable(cont, dtype=tf.float32)

      simple_ff = common_layers.linear_set_layer(32, x1_)
      cont_ff = common_layers.linear_set_layer(32, x1_, context=cont_)

      session.run(tf.global_variables_initializer())
      actual = session.run([simple_ff, cont_ff])
      self.assertEqual(actual[0].shape, (5, 4, 32))
      self.assertEqual(actual[1].shape, (5, 4, 32))

  def testRavanbakhshSetLayer(self):
    x1 = np.random.rand(5, 4, 11)
    with self.test_session() as session:
      x1_ = tf.Variable(x1, dtype=tf.float32)

      layer = common_layers.ravanbakhsh_set_layer(32, x1_)

      session.run(tf.global_variables_initializer())
      actual = session.run(layer)
      self.assertEqual(actual.shape, (5, 4, 32))


if __name__ == "__main__":
  tf.test.main()