Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean all those safe_div, _safe_div methods #21798

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 7 additions & 30 deletions tensorflow/contrib/losses/python/losses/loss_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,32 +66,6 @@ def _scale_losses(losses, weights):
return math_ops.reduce_sum(reduced_losses)


def _safe_div(numerator, denominator, name="value"):
  """Divides element-wise, yielding 0 where the denominator is not positive.

  Before dividing, every exact zero in `denominator` is swapped for a one, so
  the division op itself never sees a zero. This extra conditional is what
  avoids situations where the loss is zero causing NaNs to creep into the
  gradient computation. The quotient is then kept only at positions where
  `denominator` is strictly greater than zero; everywhere else the result is
  zero.

  Args:
    numerator: An arbitrary `Tensor`.
    denominator: A `Tensor` whose shape matches `numerator` and whose values are
      assumed to be non-negative.
    name: An optional name for the returned op.

  Returns:
    The element-wise value of the numerator divided by the denominator, with 0
    at positions where the denominator is not greater than 0.
  """
  is_positive = math_ops.greater(denominator, 0)
  # Patch zeros to ones so the division below is always well defined.
  is_zero = math_ops.equal(denominator, 0)
  patched_denominator = array_ops.where(
      is_zero, array_ops.ones_like(denominator), denominator)
  quotient = math_ops.div(numerator, patched_denominator)
  fallback = array_ops.zeros_like(numerator)
  return array_ops.where(is_positive, quotient, fallback, name=name)


def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.

Expand All @@ -104,7 +78,7 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
return _safe_div(total_loss, num_present)
return math_ops.div_no_nan(total_loss, num_present, name="value")


@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
Expand Down Expand Up @@ -609,11 +583,14 @@ def mean_pairwise_squared_error(predictions,
math_ops.square(diffs), reduction_indices=reduction_indices)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)

term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch)
term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
num_present_per_batch,
name="value")

sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
term2 = 2.0 * _safe_div(
math_ops.square(sum_diff), math_ops.square(num_present_per_batch))
term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff),
math_ops.square(num_present_per_batch),
name="value")

loss = _scale_losses(term1 - term2, weights)

Expand Down
50 changes: 20 additions & 30 deletions tensorflow/contrib/metrics/python/ops/metric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,6 @@
_EPSILON = 1e-7


def _safe_div(numerator, denominator, name):
  """Returns `numerator / denominator`, or 0 where `denominator` <= 0.

  The true division is computed for every element; the final select then
  replaces the quotient with 0 wherever the denominator is not positive.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` <= 0, else `numerator` / `denominator`
  """
  denominator_is_positive = math_ops.greater(denominator, 0)
  quotient = math_ops.truediv(numerator, denominator)
  return array_ops.where(denominator_is_positive, quotient, 0, name=name)


@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
'order of the labels and predictions arguments has been switched.')
def streaming_true_positives(predictions,
Expand Down Expand Up @@ -3238,22 +3220,28 @@ def streaming_covariance(predictions,

# We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
# batch_mean_prediction is E[x_B] in the update equation
batch_mean_prediction = _safe_div(
math_ops.reduce_sum(weighted_predictions), batch_count,
'batch_mean_prediction')
delta_mean_prediction = _safe_div(
(batch_mean_prediction - mean_prediction) * batch_count, update_count,
'delta_mean_prediction')
batch_mean_prediction = math_ops.div_no_nan(
math_ops.reduce_sum(weighted_predictions),
batch_count,
name='batch_mean_prediction')
delta_mean_prediction = math_ops.div_no_nan(
(batch_mean_prediction - mean_prediction) * batch_count,
update_count,
name='delta_mean_prediction')
update_mean_prediction = state_ops.assign_add(mean_prediction,
delta_mean_prediction)
# prev_mean_prediction is E[x_A] in the update equation
prev_mean_prediction = update_mean_prediction - delta_mean_prediction

# batch_mean_label is E[y_B] in the update equation
batch_mean_label = _safe_div(
math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
update_count, 'delta_mean_label')
batch_mean_label = math_ops.div_no_nan(
math_ops.reduce_sum(weighted_labels),
batch_count,
name='batch_mean_label')
delta_mean_label = math_ops.div_no_nan(
(batch_mean_label - mean_label) * batch_count,
update_count,
name='delta_mean_label')
update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
# prev_mean_label is E[y_A] in the update equation
prev_mean_label = update_mean_label - delta_mean_label
Expand Down Expand Up @@ -3915,8 +3903,10 @@ def _calculate_k(po, pe_row, pe_col, name):
po_sum = math_ops.reduce_sum(po)
total = math_ops.reduce_sum(pe_row)
pe_sum = math_ops.reduce_sum(
metrics_impl._safe_div( # pylint: disable=protected-access
pe_row * pe_col, total, None))
math_ops.div_no_nan(
math_ops.to_double(pe_row * pe_col),
math_ops.to_double(total),
name=None))
po_sum, pe_sum, total = (math_ops.to_double(po_sum),
math_ops.to_double(pe_sum),
math_ops.to_double(total))
Expand Down
11 changes: 3 additions & 8 deletions tensorflow/contrib/rate/rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ def name(self):
def variables(self):
return self._vars

def _safe_div(self, numerator, denominator, name):
  """Returns `numerator / denominator`, or 0 where `denominator` <= 0.

  Args:
    numerator: Value to divide.
    denominator: Value to divide by; its dtype is used for the positivity
      comparison.
    name: Name for the returned op.

  Returns:
    The true-division quotient where `denominator` is greater than zero, and
    zero (in the quotient's dtype) elsewhere.
  """
  quotient = math_ops.truediv(numerator, denominator)
  # Zeros shaped like the quotient but typed like the denominator, so the
  # comparison below is done in the denominator's dtype.
  zeros_for_compare = array_ops.zeros_like(quotient, dtype=denominator.dtype)
  is_positive = math_ops.greater(denominator, zeros_for_compare)
  # The fallback value has to share the quotient's dtype for the select.
  zeros_for_result = math_ops.cast(zeros_for_compare, quotient.dtype)
  return array_ops.where(is_positive, quotient, zeros_for_result, name=name)

def _add_variable(self, name, shape=None, dtype=None):
"""Private method for adding variables to the graph."""
if self._built:
Expand Down Expand Up @@ -148,4 +141,6 @@ def call(self, values, denominator):
state_ops.assign(self.prev_values, values)
state_ops.assign(self.prev_denominator, denominator)

return self._safe_div(self.numer, self.denom, name="safe_rate")
return math_ops.div_no_nan(self.numer,
math_ops.maximum(self.denom, 0),
name="safe_rate")
2 changes: 1 addition & 1 deletion tensorflow/python/keras/engine/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def weighted(y_true, y_pred, weights, mask=None):
score_array = math_ops.multiply(score_array, weights)
score_array = math_ops.reduce_sum(score_array)
weights = math_ops.reduce_sum(weights)
score_array = metrics_module.safe_div(score_array, weights)
score_array = math_ops.div_no_nan(score_array, weights)
return K.mean(score_array)

return weighted
Expand Down
19 changes: 1 addition & 18 deletions tensorflow/python/keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,23 +155,6 @@ def inner(*args, **kwargs):
return inner


def safe_div(numerator, denominator):
  """Element-wise division that yields 0 wherever the denominator is <= 0.

  Args:
    numerator: A `Tensor`.
    denominator: A `Tensor`, with dtype matching `numerator`.

  Returns:
    0 if `denominator` <= 0, else `numerator` / `denominator`
  """
  quotient = math_ops.truediv(numerator, denominator)
  # Zeros shaped like the quotient but typed like the denominator, so the
  # positivity check is performed in the denominator's dtype.
  zeros_for_compare = array_ops.zeros_like(quotient, dtype=denominator.dtype)
  is_positive = math_ops.greater(denominator, zeros_for_compare)
  # Re-type the zeros to the quotient's dtype for use as the fallback value.
  zeros_for_result = math_ops.cast(zeros_for_compare, quotient.dtype)
  return array_ops.where(is_positive, quotient, zeros_for_result)


def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
"""Squeeze or expand last dimension if needed.

Expand Down Expand Up @@ -505,7 +488,7 @@ def update_state(self, values, sample_weight=None):
state_ops.assign_add(self.count, num_values)

def result(self):
return safe_div(self.total, self.count)
return math_ops.div_no_nan(self.total, self.count)


class MeanMetricWrapper(Mean):
Expand Down
15 changes: 0 additions & 15 deletions tensorflow/python/kernel_tests/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
Expand All @@ -34,25 +33,11 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.losses import util
from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib


safe_div = losses_impl._safe_div # pylint: disable=protected-access


class SafeDivTest(test.TestCase):
  """Checks `safe_div` in eager mode when the denominator is zero."""

  def testEager(self):
    with context.eager_mode():
      # The zero denominator is supplied once as a tensor and once as a plain
      # Python float; both forms must produce a zero result.
      for zero_denominator in (constant_op.constant(0.0), 0.0):
        result = safe_div(constant_op.constant(1.0), zero_denominator)
        self.assertAllEqual(result, 0.0)


class AbsoluteDifferenceLossTest(test.TestCase):

def setUp(self):
Expand Down
41 changes: 11 additions & 30 deletions tensorflow/python/ops/losses/losses_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,31 +74,6 @@ def validate(cls, key):
raise ValueError("Invalid ReductionKey %s." % key)


def _safe_div(numerator, denominator, name="value"):
  """Computes a safe divide which returns 0 if the denominator is not positive.

  Two selects are involved. The inner one replaces zeros in `denominator`
  with ones ahead of the division; this additional conditional check is
  necessary for avoiding situations where the loss is zero causing NaNs to
  creep into the gradient computation. The outer one keeps the quotient only
  where `denominator` is greater than zero and emits zeros otherwise.

  Args:
    numerator: An arbitrary `Tensor`.
    denominator: `Tensor` whose shape matches `numerator` and whose values are
      assumed to be non-negative.
    name: An optional name for the returned op.

  Returns:
    The element-wise value of the numerator divided by the denominator, or 0
    where the denominator is not greater than 0.
  """
  positive_mask = math_ops.greater(denominator, 0)
  zero_mask = math_ops.equal(denominator, 0)
  # Divide by 1 instead of 0; those positions are discarded by the outer
  # select below.
  divisor = array_ops.where(
      zero_mask, array_ops.ones_like(denominator), denominator)
  ratio = math_ops.div(numerator, divisor)
  zeros = array_ops.zeros_like(numerator)
  return array_ops.where(positive_mask, ratio, zeros, name=name)


def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.

Expand All @@ -111,7 +86,7 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
return _safe_div(total_loss, num_present)
return math_ops.div_no_nan(total_loss, num_present, name="value")


def _num_present(losses, weights, per_batch=False):
Expand Down Expand Up @@ -599,14 +574,20 @@ def mean_pairwise_squared_error(
keepdims=True)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)

term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
num_present_per_batch - 1)
term1 = 2.0 * math_ops.div_no_nan(
sum_squares_diff_per_batch,
math_ops.maximum(num_present_per_batch - 1, 0),
name="value")

sum_diff = math_ops.reduce_sum(
diffs, reduction_indices=reduction_indices, keepdims=True)
term2 = 2.0 * _safe_div(
term2 = 2.0 * math_ops.div_no_nan(
math_ops.square(sum_diff),
math_ops.multiply(num_present_per_batch, num_present_per_batch - 1))
math_ops.maximum(
math_ops.multiply(num_present_per_batch,
num_present_per_batch - 1),
0),
name="value")

weighted_losses = math_ops.multiply(term1 - term2, weights)
loss = math_ops.reduce_sum(weighted_losses)
Expand Down