CLN: remove negative_to_zero argument
facaiy committed Aug 23, 2018
1 parent c05bb4e commit 01caf86
Showing 9 changed files with 38 additions and 56 deletions.
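In short, this cleanup removes the `negative_to_zero` keyword from `tf.div_no_nan`; every call site that relied on it now clamps the denominator explicitly with `math_ops.maximum(denominator, 0)`. A minimal sketch of the migration pattern follows (illustrative only; the tensor values and the `tf.Session` usage are assumptions, not part of this commit):

```python
# Illustrative sketch of the call-site migration, not code from this commit.
import tensorflow as tf

numerator = tf.constant(6.0)
denominator = tf.constant(-2.0)  # e.g. a count that may be non-positive

# Before: result = tf.div_no_nan(numerator, denominator, negative_to_zero=True)
# After: callers clamp the denominator themselves, so a non-positive
# denominator becomes 0 and div_no_nan then returns 0 instead of -3.0.
result = tf.div_no_nan(numerator, tf.maximum(denominator, 0.0))

with tf.Session() as sess:
    print(sess.run(result))  # 0.0
```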
7 changes: 3 additions & 4 deletions tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -78,8 +78,9 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
return math_ops.div_no_nan(total_loss, num_present,
negative_to_zero=True, name="value")
return math_ops.div_no_nan(total_loss,
math_ops.maximum(num_present, 0),
name="value")


@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
@@ -586,13 +587,11 @@ def mean_pairwise_squared_error(predictions,

term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
num_present_per_batch,
negative_to_zero=True,
name="value")

sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff),
math_ops.square(num_present_per_batch),
negative_to_zero=True,
name="value")

loss = _scale_losses(term1 - term2, weights)
8 changes: 2 additions & 6 deletions tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -3189,11 +3189,9 @@ def streaming_covariance(predictions,
# batch_mean_prediction is E[x_B] in the update equation
batch_mean_prediction = math_ops.div_no_nan(
math_ops.reduce_sum(weighted_predictions), batch_count,
negative_to_zero=True,
name='batch_mean_prediction')
delta_mean_prediction = math_ops.div_no_nan(
(batch_mean_prediction - mean_prediction) * batch_count, update_count,
negative_to_zero=True,
name='delta_mean_prediction')
update_mean_prediction = state_ops.assign_add(mean_prediction,
delta_mean_prediction)
@@ -3203,11 +3201,9 @@ def streaming_covariance(predictions,
# batch_mean_label is E[y_B] in the update equation
batch_mean_label = math_ops.div_no_nan(
math_ops.reduce_sum(weighted_labels), batch_count,
negative_to_zero=True,
name='batch_mean_label')
delta_mean_label = math_ops.div_no_nan(
(batch_mean_label - mean_label) * batch_count, update_count,
negative_to_zero=True,
name='delta_mean_label')
update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
# prev_mean_label is E[y_A] in the update equation
@@ -3871,8 +3867,8 @@ def _calculate_k(po, pe_row, pe_col, name):
total = math_ops.reduce_sum(pe_row)
pe_sum = math_ops.reduce_sum(
math_ops.div_no_nan(
pe_row * pe_col, total,
negative_to_zero=True,
pe_row * pe_col,
math_ops.maximum(total, 0),
name=None))
po_sum, pe_sum, total = (math_ops.to_double(po_sum),
math_ops.to_double(pe_sum),
4 changes: 2 additions & 2 deletions tensorflow/contrib/rate/rate.py
@@ -141,6 +141,6 @@ def call(self, values, denominator):
state_ops.assign(self.prev_values, values)
state_ops.assign(self.prev_denominator, denominator)

return math_ops.div_no_nan(self.numer, self.denom,
negative_to_zero=True,
return math_ops.div_no_nan(self.numer,
math_ops.maximum(self.denom, 0),
name="safe_rate")
4 changes: 2 additions & 2 deletions tensorflow/python/keras/engine/training_utils.py
@@ -607,8 +607,8 @@ def weighted(y_true, y_pred, weights, mask=None):
score_array = math_ops.multiply(score_array, weights)
score_array = math_ops.reduce_sum(score_array)
weights = math_ops.reduce_sum(weights)
score_array = math_ops.div_no_nan(score_array, weights,
negative_to_zero=True)
score_array = math_ops.div_no_nan(score_array,
math_ops.maximum(weights, 0))
return K.mean(score_array)

return weighted
2 changes: 1 addition & 1 deletion tensorflow/python/keras/metrics.py
@@ -455,7 +455,7 @@ def update_state(self, values, sample_weight=None):
state_ops.assign_add(self.count, num_values)

def result(self):
return math_ops.div_no_nan(self.total, self.count, negative_to_zero=True)
return math_ops.div_no_nan(self.total, math_ops.maximum(self.count, 0))


class MeanMetricWrapper(Mean):
18 changes: 10 additions & 8 deletions tensorflow/python/ops/losses/losses_impl.py
@@ -86,8 +86,9 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
return math_ops.div_no_nan(total_loss, num_present,
negative_to_zero=True, name="value")
return math_ops.div_no_nan(total_loss,
math_ops.maximum(num_present, 0),
name="value")


def _num_present(losses, weights, per_batch=False):
@@ -575,17 +576,18 @@ def mean_pairwise_squared_error(
keepdims=True)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)

term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
num_present_per_batch - 1,
negative_to_zero=True,
name="value")
term1 = 2.0 * math_ops.div_no_nan(
sum_squares_diff_per_batch,
math_ops.maximum(num_present_per_batch - 1, 0),
name="value")

sum_diff = math_ops.reduce_sum(
diffs, reduction_indices=reduction_indices, keepdims=True)
term2 = 2.0 * math_ops.div_no_nan(
math_ops.square(sum_diff),
math_ops.multiply(num_present_per_batch, num_present_per_batch - 1),
negative_to_zero=True,
math_ops.maximum(
math_ops.multiply(num_present_per_batch, num_present_per_batch - 1),
0),
name="value")

weighted_losses = math_ops.multiply(term1 - term2, weights)
5 changes: 1 addition & 4 deletions tensorflow/python/ops/math_ops.py
@@ -1039,14 +1039,13 @@ def div(x, y, name=None):


@tf_export("div_no_nan")
def div_no_nan(x, y, name=None, negative_to_zero=False):
def div_no_nan(x, y, name=None):
"""Computes an unsafe divide which returns 0 if the y is zero.
Args:
x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
y: A `Tensor` whose dtype is compatible with `x`.
name: A name for the operation (optional).
negative_to_zero: If `True`, negative is treated as zero in denominator.
Returns:
The element-wise value of `x` divided by `y`.
"""
@@ -1059,8 +1058,6 @@ def div_no_nan(x, y, name=None, negative_to_zero=False):
if x_dtype != y_dtype:
raise TypeError("x and y must have the same dtype, got %r != %r" %
(x_dtype, y_dtype))
if negative_to_zero:
y = gen_math_ops.maximum(y, 0, name='negative_to_zero')
return gen_math_ops.div_no_nan(x, y, name=name)


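With the keyword gone, `div_no_nan(x, y)` is back to a single behavior: element-wise `x / y`, with 0 wherever `y` is 0 and no special-casing of negative denominators. A small usage sketch of those semantics (values and session usage are illustrative, not part of the diff):

```python
# Illustrative sketch of div_no_nan semantics after this change.
import tensorflow as tf

x = tf.constant([2.0, 4.0, 6.0])
y = tf.constant([1.0, 0.0, -2.0])

# Element-wise x / y, except positions where y == 0 yield 0 instead of Inf/NaN.
# The negative denominator is no longer clamped, so 6 / -2 stays -3.
z = tf.div_no_nan(x, y)

with tf.Session() as sess:
    print(sess.run(z))  # [ 2.  0. -3.]
```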
13 changes: 0 additions & 13 deletions tensorflow/python/ops/math_ops_test.py
@@ -487,19 +487,6 @@ def testBasic(self):
tf_result = math_ops.div_no_nan(nums, divs).eval()
self.assertAllEqual(tf_result, np_result)

def testNegativeToZero(self):
for dtype in [np.float32, np.float64]:
nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)

np_result = np.true_divide(nums, divs)
np_result[:, divs[0] <= 0] = 0

with self.cached_session():
tf_result = math_ops.div_no_nan(nums, divs,
negative_to_zero=True).eval()
self.assertAllEqual(tf_result, np_result)


if __name__ == "__main__":
googletest.main()
33 changes: 17 additions & 16 deletions tensorflow/python/ops/metrics_impl.py
@@ -379,12 +379,13 @@ def mean(values,
update_count_op = state_ops.assign_add(count, num_values)

compute_mean = lambda _, t, c: math_ops.div_no_nan(
t, c, negative_to_zero=True, name='value')
t, math_ops.maximum(c, 0), name='value')

mean_t = _aggregate_across_towers(
metrics_collections, compute_mean, total, count)
update_op = math_ops.div_no_nan(update_total_op, update_count_op,
negative_to_zero=True, name='update_op')
update_op = math_ops.div_no_nan(update_total_op,
math_ops.maximum(update_count_op, 0),
name='update_op')

if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -756,21 +757,21 @@ def interpolate_pr_auc(tp, fp, fn):
"""
dtp = tp[:num_thresholds - 1] - tp[1:]
p = tp + fp
prec_slope = math_ops.div_no_nan(dtp, p[:num_thresholds - 1] - p[1:],
negative_to_zero=True,
name='prec_slope')
prec_slope = math_ops.div_no_nan(
dtp,
math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
name='prec_slope')
intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
safe_p_ratio = array_ops.where(
math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
math_ops.div_no_nan(p[:num_thresholds - 1], p[1:],
negative_to_zero=True,
math_ops.div_no_nan(p[:num_thresholds - 1],
math_ops.maximum(p[1:], 0),
name='recall_relative_ratio'),
array_ops.ones_like(p[1:]))
return math_ops.reduce_sum(
math_ops.div_no_nan(
prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
tp[1:] + fn[1:],
negative_to_zero=True,
math_ops.maximum(tp[1:] + fn[1:], 0),
name='pr_auc_increment'),
name='interpolate_pr_auc')

@@ -1052,16 +1053,16 @@ def mean_per_class_accuracy(labels,

def compute_mean_accuracy(_, count, total):
per_class_accuracy = math_ops.div_no_nan(
count, total, negative_to_zero=True, name=None)
count, math_ops.maximum(total, 0), name=None)
mean_accuracy_v = math_ops.reduce_mean(
per_class_accuracy, name='mean_accuracy')
return mean_accuracy_v

mean_accuracy_v = _aggregate_across_towers(
metrics_collections, compute_mean_accuracy, count, total)

update_op = math_ops.div_no_nan(update_count_op, update_total_op,
negative_to_zero=True,
update_op = math_ops.div_no_nan(update_count_op,
math_ops.maximum(update_total_op, 0),
name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1372,13 +1373,13 @@ def mean_tensor(values,
update_count_op = state_ops.assign_add(count, num_values)

compute_mean = lambda _, t, c: math_ops.div_no_nan(
t, c, negative_to_zero=True, name='value')
t, math_ops.maximum(c, 0), name='value')

mean_t = _aggregate_across_towers(
metrics_collections, compute_mean, total, count)

update_op = math_ops.div_no_nan(update_total_op, update_count_op,
negative_to_zero=True,
update_op = math_ops.div_no_nan(update_total_op,
math_ops.maximum(update_count_op, 0),
name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
