Skip to content

Commit

Permalink
CLN: replace safe_div method by div_no_nan
Browse files Browse the repository at this point in the history
  • Loading branch information
facaiy committed Aug 22, 2018
1 parent 56ea7fc commit c05bb4e
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 163 deletions.
40 changes: 10 additions & 30 deletions tensorflow/contrib/losses/python/losses/loss_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,32 +66,6 @@ def _scale_losses(losses, weights):
return math_ops.reduce_sum(reduced_losses)


def _safe_div(numerator, denominator, name="value"):
"""Computes a safe divide which returns 0 if the denominator is zero.
Note that the function contains an additional conditional check that is
necessary for avoiding situations where the loss is zero causing NaNs to
creep into the gradient computation.
Args:
numerator: An arbitrary `Tensor`.
denominator: A `Tensor` whose shape matches `numerator` and whose values are
assumed to be non-negative.
name: An optional name for the returned op.
Returns:
The element-wise value of the numerator divided by the denominator.
"""
return array_ops.where(
math_ops.greater(denominator, 0),
math_ops.div(numerator,
array_ops.where(
math_ops.equal(denominator, 0),
array_ops.ones_like(denominator), denominator)),
array_ops.zeros_like(numerator),
name=name)


def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
Expand All @@ -104,7 +78,8 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
return _safe_div(total_loss, num_present)
return math_ops.div_no_nan(total_loss, num_present,
negative_to_zero=True, name="value")


@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
Expand Down Expand Up @@ -609,11 +584,16 @@ def mean_pairwise_squared_error(predictions,
math_ops.square(diffs), reduction_indices=reduction_indices)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)

term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch)
term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
num_present_per_batch,
negative_to_zero=True,
name="value")

sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
term2 = 2.0 * _safe_div(
math_ops.square(sum_diff), math_ops.square(num_present_per_batch))
term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff),
math_ops.square(num_present_per_batch),
negative_to_zero=True,
name="value")

loss = _scale_losses(term1 - term2, weights)

Expand Down
46 changes: 18 additions & 28 deletions tensorflow/contrib/metrics/python/ops/metric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,6 @@
_EPSILON = 1e-7


def _safe_div(numerator, denominator, name):
"""Divides two values, returning 0 if the denominator is <= 0.
Args:
numerator: A real `Tensor`.
denominator: A real `Tensor`, with dtype matching `numerator`.
name: Name for the returned op.
Returns:
0 if `denominator` <= 0, else `numerator` / `denominator`
"""
return array_ops.where(
math_ops.greater(denominator, 0),
math_ops.truediv(numerator, denominator),
0,
name=name)


@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
'order of the labels and predictions arguments has been switched.')
def streaming_true_positives(predictions,
Expand Down Expand Up @@ -3205,22 +3187,28 @@ def streaming_covariance(predictions,

# We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
# batch_mean_prediction is E[x_B] in the update equation
batch_mean_prediction = _safe_div(
batch_mean_prediction = math_ops.div_no_nan(
math_ops.reduce_sum(weighted_predictions), batch_count,
'batch_mean_prediction')
delta_mean_prediction = _safe_div(
negative_to_zero=True,
name='batch_mean_prediction')
delta_mean_prediction = math_ops.div_no_nan(
(batch_mean_prediction - mean_prediction) * batch_count, update_count,
'delta_mean_prediction')
negative_to_zero=True,
name='delta_mean_prediction')
update_mean_prediction = state_ops.assign_add(mean_prediction,
delta_mean_prediction)
# prev_mean_prediction is E[x_A] in the update equation
prev_mean_prediction = update_mean_prediction - delta_mean_prediction

# batch_mean_label is E[y_B] in the update equation
batch_mean_label = _safe_div(
math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
update_count, 'delta_mean_label')
batch_mean_label = math_ops.div_no_nan(
math_ops.reduce_sum(weighted_labels), batch_count,
negative_to_zero=True,
name='batch_mean_label')
delta_mean_label = math_ops.div_no_nan(
(batch_mean_label - mean_label) * batch_count, update_count,
negative_to_zero=True,
name='delta_mean_label')
update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
# prev_mean_label is E[y_A] in the update equation
prev_mean_label = update_mean_label - delta_mean_label
Expand Down Expand Up @@ -3882,8 +3870,10 @@ def _calculate_k(po, pe_row, pe_col, name):
po_sum = math_ops.reduce_sum(po)
total = math_ops.reduce_sum(pe_row)
pe_sum = math_ops.reduce_sum(
metrics_impl._safe_div( # pylint: disable=protected-access
pe_row * pe_col, total, None))
math_ops.div_no_nan(
pe_row * pe_col, total,
negative_to_zero=True,
name=None))
po_sum, pe_sum, total = (math_ops.to_double(po_sum),
math_ops.to_double(pe_sum),
math_ops.to_double(total))
Expand Down
11 changes: 3 additions & 8 deletions tensorflow/contrib/rate/rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ def name(self):
def variables(self):
return self._vars

def _safe_div(self, numerator, denominator, name):
t = math_ops.truediv(numerator, denominator)
zero = array_ops.zeros_like(t, dtype=denominator.dtype)
condition = math_ops.greater(denominator, zero)
zero = math_ops.cast(zero, t.dtype)
return array_ops.where(condition, t, zero, name=name)

def _add_variable(self, name, shape=None, dtype=None):
"""Private method for adding variables to the graph."""
if self._built:
Expand Down Expand Up @@ -148,4 +141,6 @@ def call(self, values, denominator):
state_ops.assign(self.prev_values, values)
state_ops.assign(self.prev_denominator, denominator)

return self._safe_div(self.numer, self.denom, name="safe_rate")
return math_ops.div_no_nan(self.numer, self.denom,
negative_to_zero=True,
name="safe_rate")
3 changes: 2 additions & 1 deletion tensorflow/python/keras/engine/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,8 @@ def weighted(y_true, y_pred, weights, mask=None):
score_array = math_ops.multiply(score_array, weights)
score_array = math_ops.reduce_sum(score_array)
weights = math_ops.reduce_sum(weights)
score_array = metrics_module.safe_div(score_array, weights)
score_array = math_ops.div_no_nan(score_array, weights,
negative_to_zero=True)
return K.mean(score_array)

return weighted
Expand Down
19 changes: 1 addition & 18 deletions tensorflow/python/keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,23 +136,6 @@ def merge_fn_wrapper(distribution, merge_fn, *args):
return tf_decorator.make_decorator(result_fn, decorated)


def safe_div(numerator, denominator):
"""Divides two tensors element-wise, returning 0 if the denominator is <= 0.
Args:
numerator: A `Tensor`.
denominator: A `Tensor`, with dtype matching `numerator`.
Returns:
0 if `denominator` <= 0, else `numerator` / `denominator`
"""
t = math_ops.truediv(numerator, denominator)
zero = array_ops.zeros_like(t, dtype=denominator.dtype)
condition = math_ops.greater(denominator, zero)
zero = math_ops.cast(zero, t.dtype)
return array_ops.where(condition, t, zero)


def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
"""Squeeze or expand last dimension if needed.
Expand Down Expand Up @@ -472,7 +455,7 @@ def update_state(self, values, sample_weight=None):
state_ops.assign_add(self.count, num_values)

def result(self):
return safe_div(self.total, self.count)
return math_ops.div_no_nan(self.total, self.count, negative_to_zero=True)


class MeanMetricWrapper(Mean):
Expand Down
14 changes: 0 additions & 14 deletions tensorflow/python/kernel_tests/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,11 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.losses import util
from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib


safe_div = losses_impl._safe_div # pylint: disable=protected-access


class SafeDivTest(test.TestCase):

def testEager(self):
with context.eager_mode():
self.assertAllEqual(safe_div(constant_op.constant(1.0),
constant_op.constant(0.0)), 0.0)
self.assertAllEqual(safe_div(constant_op.constant(1.0),
0.0), 0.0)


class AbsoluteDifferenceLossTest(test.TestCase):

def setUp(self):
Expand Down
40 changes: 10 additions & 30 deletions tensorflow/python/ops/losses/losses_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,31 +74,6 @@ def validate(cls, key):
raise ValueError("Invalid ReductionKey %s." % key)


def _safe_div(numerator, denominator, name="value"):
"""Computes a safe divide which returns 0 if the denominator is zero.
Note that the function contains an additional conditional check that is
necessary for avoiding situations where the loss is zero causing NaNs to
creep into the gradient computation.
Args:
numerator: An arbitrary `Tensor`.
denominator: `Tensor` whose shape matches `numerator` and whose values are
assumed to be non-negative.
name: An optional name for the returned op.
Returns:
The element-wise value of the numerator divided by the denominator.
"""
return array_ops.where(
math_ops.greater(denominator, 0),
math_ops.div(numerator, array_ops.where(
math_ops.equal(denominator, 0),
array_ops.ones_like(denominator), denominator)),
array_ops.zeros_like(numerator),
name=name)


def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
Expand All @@ -111,7 +86,8 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
return _safe_div(total_loss, num_present)
return math_ops.div_no_nan(total_loss, num_present,
negative_to_zero=True, name="value")


def _num_present(losses, weights, per_batch=False):
Expand Down Expand Up @@ -599,14 +575,18 @@ def mean_pairwise_squared_error(
keepdims=True)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)

term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
num_present_per_batch - 1)
term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
num_present_per_batch - 1,
negative_to_zero=True,
name="value")

sum_diff = math_ops.reduce_sum(
diffs, reduction_indices=reduction_indices, keepdims=True)
term2 = 2.0 * _safe_div(
term2 = 2.0 * math_ops.div_no_nan(
math_ops.square(sum_diff),
math_ops.multiply(num_present_per_batch, num_present_per_batch - 1))
math_ops.multiply(num_present_per_batch, num_present_per_batch - 1),
negative_to_zero=True,
name="value")

weighted_losses = math_ops.multiply(term1 - term2, weights)
loss = math_ops.reduce_sum(weighted_losses)
Expand Down

0 comments on commit c05bb4e

Please sign in to comment.