Commit

Extend softmax and logsoftmax to make them work on an arbitrary dimension of a non-scalar tensor.

Change: 131540860
Yuefeng Zhou authored and tensorflower-gardener committed Aug 28, 2016
1 parent cf35735 commit aeac274
Showing 4 changed files with 194 additions and 29 deletions.
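
Before the diffs, a minimal usage sketch of the new behavior, written against the API as of this commit (tf.Session and the dim keyword; later TensorFlow releases rename dim to axis), for illustration only:

import numpy as np
import tensorflow as tf

x = np.array([[[1., 2.], [3., 4.]],
              [[5., 6.], [7., 8.]]], dtype=np.float32)

with tf.Session() as sess:
  # Default: softmax over the last dimension, as before this change.
  over_last = sess.run(tf.nn.softmax(x))
  # New: softmax over an arbitrary dimension of a non-scalar tensor.
  over_first = sess.run(tf.nn.softmax(x, dim=0))

print(np.sum(over_last, axis=-1))   # all ones
print(np.sum(over_first, axis=0))   # all ones
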
@@ -100,7 +100,7 @@ def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
with tf.name_scope('Wx_plus_b'):
preactivate = tf.matmul(input_tensor, weights) + biases
tf.histogram_summary(layer_name + '/pre_activations', preactivate)
activations = act(preactivate, 'activation')
activations = act(preactivate, name='activation')
tf.histogram_summary(layer_name + '/activations', activations)
return activations

90 changes: 64 additions & 26 deletions tensorflow/python/kernel_tests/softmax_op_test.py
@@ -26,37 +26,39 @@

class SoftmaxTest(tf.test.TestCase):

def _npSoftmax(self, features, log=False):
batch_dim = 0
class_dim = 1
batch_size = features.shape[batch_dim]
e = np.exp(features -
np.reshape(np.amax(features, axis=class_dim), [batch_size, 1]))
softmax = e / np.reshape(np.sum(e, axis=class_dim), [batch_size, 1])
def _npSoftmax(self, features, dim=-1, log=False):
if dim is -1:
dim = len(features.shape) - 1
one_only_on_dim = list(features.shape)
one_only_on_dim[dim] = 1
e = np.exp(features - np.reshape(
np.amax(
features, axis=dim), one_only_on_dim))
softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim)
if log:
return np.log(softmax)
else:
return softmax

def _testSoftmax(self, np_features, log=False, use_gpu=False):
def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
# A previous version of the code checked the op name rather than the op type
# to distinguish between log and non-log. Use an arbitrary name to catch
# this bug in future.
name = "arbitrary"
np_softmax = self._npSoftmax(np_features, log=log)
np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
with self.test_session(use_gpu=use_gpu):
if log:
tf_softmax = tf.nn.log_softmax(np_features, name=name)
tf_softmax = tf.nn.log_softmax(np_features, dim=dim, name=name)
else:
tf_softmax = tf.nn.softmax(np_features, name=name)
tf_softmax = tf.nn.softmax(np_features, dim=dim, name=name)
out = tf_softmax.eval()
self.assertAllCloseAccordingToType(np_softmax, out)
self.assertShapeEqual(np_softmax, tf_softmax)
if not log:
# Bonus check: the softmaxes should add to one in each
# batch element.
self.assertAllCloseAccordingToType(np.ones(out.shape[0]),
np.sum(out, axis=1))
# Bonus check: the softmaxes should add to one in dimension dim.
sum_along_dim = np.sum(out, axis=dim)
self.assertAllCloseAccordingToType(
np.ones(sum_along_dim.shape), sum_along_dim)

def _testAll(self, features):
self._testSoftmax(features, use_gpu=False)
@@ -90,17 +92,11 @@ def testNpSoftmax(self):
np_lsm,
rtol=1.e-5, atol=1.e-5)

def testShapeMismatch(self):
with self.assertRaises(ValueError):
tf.nn.softmax([0., 1., 2., 3.])
with self.assertRaises(ValueError):
tf.nn.log_softmax([0., 1., 2., 3.])

def _testOverflow(self, use_gpu=False):
if use_gpu:
type = np.float32
else:
type = np.float64
max = np.finfo(type).max
features = np.array(
[[1., 1., 1., 1.],
@@ -128,13 +124,55 @@ def testDouble(self):
use_gpu=False)
self._testOverflow(use_gpu=False)

def test1DTensorAsInput(self):
self._testSoftmax(
np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)
self._testOverflow(use_gpu=False)

def test3DTensorAsInput(self):
self._testSoftmax(
np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
[[2., 3., 4., 5.], [6., 7., 8., 9.]],
[[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
use_gpu=False)
self._testOverflow(use_gpu=False)

def testAlongFirstDimension(self):
self._testSoftmax(
np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
[[2., 3., 4., 5.], [6., 7., 8., 9.]],
[[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
dim=0,
use_gpu=False)
self._testOverflow(use_gpu=False)

def testAlongSecondDimension(self):
self._testSoftmax(
np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
[[2., 3., 4., 5.], [6., 7., 8., 9.]],
[[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
dim=1,
use_gpu=False)
self._testOverflow(use_gpu=False)

def testShapeInference(self):
op = tf.nn.softmax([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
[[2., 3., 4., 5.], [6., 7., 8., 9.]],
[[5., 4., 3., 2.], [1., 2., 3., 4.]]])
self.assertEqual([3, 2, 4], op.get_shape())

def testEmpty(self):
def testEmptyInput(self):
with self.test_session():
x = tf.constant([[]], shape=[0, 3])
self.assertEqual(0, tf.size(x).eval())
expected_y = np.array([]).reshape(0, 3)
np.testing.assert_array_equal(expected_y, tf.nn.softmax(x).eval())
# reshape would raise if logits is empty
with self.assertRaises(tf.errors.InvalidArgumentError):
tf.nn.softmax(x, dim=0).eval()

def testDimTooLarge(self):
with self.test_session():
with self.assertRaises(tf.errors.InvalidArgumentError):
tf.nn.softmax([1., 2., 3., 4.], dim=100).eval()


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions tensorflow/python/ops/hidden_ops.txt
@@ -179,6 +179,8 @@ BiasAddV1
Relu6
AvgPool
MaxPool
Softmax
LogSoftmax

# parsing_ops
ParseExample
129 changes: 127 additions & 2 deletions tensorflow/python/ops/nn_ops.py
@@ -468,6 +468,129 @@ def relu6(features, name=None):
features = ops.convert_to_tensor(features, name="features")
return gen_nn_ops._relu6(features, name=name)

def _softmax(logits, compute_op, dim=-1, name=None):
"""Helper function for softmax and log_softmax.
It reshapes and transposes the input logits into a 2-D Tensor and then invokes
the tf.nn._softmax or tf.nn._log_softmax function. The output is then
transposed and reshaped back.
Args:
logits: A non-empty `Tensor`. Must be one of the following types: `half`,
`float32`, `float64`.
compute_op: Either gen_nn_ops._softmax or gen_nn_ops._log_softmax
dim: The dimension softmax would be performed on. The default is -1 which
indicates the last dimension.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
Raises:
InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
dimension of `logits`.
"""
# Helper function to swap dim_index and last_index of logits. last_index must
# be logits' last dimension.
def _swap_axis(logits, dim_index, last_index):
return array_ops.transpose(logits, array_ops.concat(
0, [math_ops.range(dim_index), [last_index],
math_ops.range(dim_index + 1, last_index), [dim_index]]))

# Helper function to flatten logits' outer dimensions and keep its last
# dimension.
def _flatten_outer_dims(logits):
rank = array_ops.rank(logits)
last_dim_size = array_ops.slice(
array_ops.shape(logits), [math_ops.sub(rank, 1)], [1])
return array_ops.reshape(logits, array_ops.concat(0, [[-1], last_dim_size]))

logits = ops.convert_to_tensor(logits)
if logits.get_shape().ndims is 2 and dim is -1:
return compute_op(logits, name=name)

# We need its original shape for shape inference.
shape = logits.get_shape()

# If dim is the last dimension, simply reshape the logits to a matrix and
# apply the internal softmax.
if dim is -1:
input_shape = array_ops.shape(logits)
logits = _flatten_outer_dims(logits)
output = compute_op(logits, name=name)
output = array_ops.reshape(output, input_shape)
return output

# If dim is not the last dimension, we have to do a reshape and transpose so
# that we can still perform softmax on its last dimension.

# Swap logits' dimension of dim and its last dimension.
input_rank = array_ops.rank(logits)
logits = _swap_axis(logits, dim, math_ops.sub(input_rank, 1))
shape_after_swap = array_ops.shape(logits)

# Reshape logits into a matrix.
logits = _flatten_outer_dims(logits)

# Do the actual softmax on its last dimension.
output = compute_op(logits, name=name)

# Transform back the output tensor.
output = array_ops.reshape(output, shape_after_swap)
output = _swap_axis(output, dim, math_ops.sub(input_rank, 1))

# Make shape inference work since reshape and transpose may erase its static
# shape.
output.set_shape(shape)

return output
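
For illustration only (not part of the commit): the swap-and-flatten strategy used by _softmax above can be mirrored in plain NumPy. The helper name softmax_along is hypothetical.

import numpy as np

def softmax_along(x, dim):
  # Swap the requested dimension with the last one, as _swap_axis does above.
  last = x.ndim - 1
  x_swapped = np.swapaxes(x, dim, last)
  # Flatten the outer dimensions so softmax runs on a 2-D matrix,
  # as _flatten_outer_dims does above.
  flat = x_swapped.reshape(-1, x_swapped.shape[-1])
  e = np.exp(flat - np.max(flat, axis=1, keepdims=True))
  out = e / np.sum(e, axis=1, keepdims=True)
  # Undo the reshape and the swap to restore the original layout.
  return np.swapaxes(out.reshape(x_swapped.shape), dim, last)

x = np.random.rand(3, 2, 4).astype(np.float32)
assert np.allclose(np.sum(softmax_along(x, 1), axis=1), 1.0)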


def softmax(logits, dim=-1, name=None):
"""Computes log softmax activations.
For each batch `i` and class `j` we have
softmax = exp(logits) / reduce_sum(exp(logits), dim)
Args:
logits: A non-empty `Tensor`. Must be one of the following types: `half`,
`float32`, `float64`.
dim: The dimension softmax would be performed on. The default is -1 which
indicates the last dimension.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
Raises:
InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
dimension of `logits`.
"""
return _softmax(logits, gen_nn_ops._softmax, dim, name)
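
A tiny NumPy check of the softmax formula above, for illustration only:

import numpy as np

logits = np.array([1.0, 2.0, 3.0])
probs = np.exp(logits) / np.sum(np.exp(logits))
print(probs)        # approx [0.090, 0.245, 0.665]
print(probs.sum())  # 1.0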


def log_softmax(logits, dim=-1, name=None):
"""Computes log softmax activations.
For each batch `i` and class `j` we have
logsoftmax = logits - log(reduce_sum(exp(logits), dim))
Args:
logits: A non-empty `Tensor`. Must be one of the following types: `half`,
`float32`, `float64`.
dim: The dimension softmax would be performed on. The default is -1 which
indicates the last dimension.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
Raises:
InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
dimension of `logits`.
"""
return _softmax(logits, gen_nn_ops._log_softmax, dim, name)
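
A matching NumPy check for log_softmax, for illustration only; it relies on the identity log_softmax(x) = x - log(sum(exp(x))):

import numpy as np

logits = np.array([1.0, 2.0, 3.0])
log_probs = logits - np.log(np.sum(np.exp(logits)))
print(np.allclose(np.exp(log_probs).sum(), 1.0))                               # True
print(np.allclose(log_probs, np.log(np.exp(logits) / np.exp(logits).sum())))   # True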


def softmax_cross_entropy_with_logits(logits, labels, name=None):
"""Computes softmax cross entropy between `logits` and `labels`.
@@ -727,9 +850,11 @@ def _LRNGradShape(op):
return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)]


ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank(2))
ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank_at_least(
1))

ops.RegisterShape("LogSoftmax")(common_shapes.unchanged_shape_with_rank(2))
ops.RegisterShape("LogSoftmax")(
common_shapes.unchanged_shape_with_rank_at_least(1))


@ops.RegisterShape("InTopK")
