sampled version of sparse_softmax_cross_entropy_with_logits #13453

Merged · 6 commits · Dec 13, 2017
1 change: 1 addition & 0 deletions tensorflow/contrib/nn/__init__.py
@@ -19,6 +19,7 @@
@@deprecated_flipped_sparse_softmax_cross_entropy_with_logits
@@deprecated_flipped_sigmoid_cross_entropy_with_logits
@@rank_sampled_softmax_loss
@@sampled_sparse_softmax_loss
@@scaled_softplus
"""

100 changes: 100 additions & 0 deletions tensorflow/contrib/nn/python/ops/sampling_ops.py
@@ -24,6 +24,8 @@
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops


def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
@@ -240,3 +242,101 @@ def rank_sampled_softmax_loss(weights,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)


def sampled_sparse_softmax_loss(weights,
biases,
labels,
inputs,
num_sampled,
num_classes,
sampled_values=None,
remove_accidental_hits=True,
partition_strategy="mod",
name="sampled_sparse_softmax_loss"):
"""Computes and returns the sampled sparse softmax training loss.

This is a faster way to train a softmax classifier over a huge number of
classes.

This operation is for training only. It is generally an underestimate of
the full softmax loss.

A common use case is to use this method for training, and calculate the full
softmax loss for evaluation or inference. In this case, you must set
`partition_strategy="div"` for the two losses to be consistent, as in the
following example:

```python
if mode == "train":
loss = tf.nn.sampled_sparse_softmax_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
...,
partition_strategy="div")
elif mode == "eval":
logits = tf.matmul(inputs, tf.transpose(weights))
logits = tf.nn.bias_add(logits, biases)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.squeeze(labels),
logits=logits)
```

See our [Candidate Sampling Algorithms Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf).

Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
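
As a rough sketch (an illustration, not the exact code path): because this
implementation calls `_compute_sampled_logits` with `subtract_log_q=True`,
the training logit for each candidate class `s` is approximately

```python
logit[i, s] = dot(inputs[i], weights[s]) + biases[s] - log(Q(s))
```

where `Q(s)` is the candidate sampler's expected count for class `s`;
subtracting `log(Q(s))` corrects the logits for the sampling distribution.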

Args:
weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
objects whose concatenation along dimension 0 has shape
`[num_classes, dim]`. The (possibly-sharded) class embeddings.
biases: A `Tensor` of shape `[num_classes]`. The class biases.
labels: A `Tensor` of type `int64` and shape `[batch_size, 1]`.
The index of the single target class for each row of logits. Note that
this format differs from the `labels` argument of
`nn.sparse_softmax_cross_entropy_with_logits`.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
num_classes: An `int`. The number of possible classes.
sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
`sampled_expected_count`) returned by a `*_candidate_sampler` function.
If `None`, we default to `log_uniform_candidate_sampler`.
remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
where a sampled class equals one of the target classes. Default is
`True`.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: A name for the operation (optional).

Returns:
A `batch_size` 1-D tensor of per-example sampled softmax losses.

"""
logits, _ = nn_impl._compute_sampled_logits(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_sampled,
num_classes=num_classes,
num_true=1,
sampled_values=sampled_values,
subtract_log_q=True,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)

# There is only one true label. _compute_sampled_logits puts the true logit
# at index 0.
labels = array_ops.zeros([array_ops.shape(logits)[0], 1], dtype=dtypes.int64)

sampled_losses = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=array_ops.squeeze(labels), logits=logits)
# sampled_losses is a [batch_size] tensor.
return sampled_losses
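
For reviewers, a minimal usage sketch of the new op (shapes and values are
illustrative, not from this PR; it assumes the `tf.contrib.nn` export added
in `__init__.py` above):

```python
import tensorflow as tf

batch_size, dim, num_classes, num_sampled = 32, 128, 10000, 64

weights = tf.Variable(tf.random_normal([num_classes, dim]))
biases = tf.Variable(tf.zeros([num_classes]))
inputs = tf.random_normal([batch_size, dim])  # forward activations
# One int64 target class per example, shape [batch_size, 1].
labels = tf.random_uniform(
    [batch_size, 1], maxval=num_classes, dtype=tf.int64)

loss = tf.contrib.nn.sampled_sparse_softmax_loss(
    weights=weights,
    biases=biases,
    labels=labels,
    inputs=inputs,
    num_sampled=num_sampled,
    num_classes=num_classes,
    partition_strategy="div")  # [batch_size] per-example losses
```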
13 changes: 8 additions & 5 deletions tensorflow/python/ops/nn_impl.py
@@ -26,6 +26,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -932,10 +933,11 @@ class biases.
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: A name for the operation (optional).
Returns:
out_logits: `Tensor` object with shape
`[batch_size, num_true + num_sampled]`, for passing to either
`nn.sigmoid_cross_entropy_with_logits` (NCE) or
`nn.softmax_cross_entropy_with_logits` (sampled softmax).
out_labels: A `Tensor` object with the same shape as `out_logits`.
"""

if isinstance(weights, variables.PartitionedVariable):
@@ -1046,15 +1048,16 @@ class biases.

# Construct output logits and labels. The true labels/logits start at col 0.
out_logits = array_ops.concat([true_logits, sampled_logits], 1)
# true_logits is a float tensor, ones_like(true_logits) is a float
# tensor of ones. We then divide by num_true to ensure the per-example
# labels sum to 1.0, i.e. form a proper probability distribution.
out_labels = array_ops.concat([
array_ops.ones_like(true_logits) / num_true,
array_ops.zeros_like(sampled_logits)
], 1)
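# Illustrative aside (not part of this change): with num_true=2 and
# num_sampled=3, each row of out_labels is [0.5, 0.5, 0.0, 0.0, 0.0],
# matching the [true | sampled] column layout of out_logits above.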

return out_logits, out_labels


def nce_loss(weights,