Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tf.clip_by_global_norm raise exception when use_norm = inf #21428

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions tensorflow/python/kernel_tests/clip_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
Expand Down Expand Up @@ -369,6 +370,21 @@ def testClipByGlobalNormZero(self):
self.assertAllClose(np_ans_0, tf_ans_1)
self.assertAllClose(np_ans_1, tf_ans_2)

def testClipByGlobalNormInf(self):
with self.test_session(use_gpu=True):
x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0],
shape=[2, 3])
x1 = constant_op.constant([1.0, -2.0])
clip_norm = 6.0

ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm)
with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
norm.eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
ans[0].eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
ans[1].eval()

def testClipByAverageNormClipped(self):
# Norm clipping when average clip_norm < 0.83333333
with self.test_session(use_gpu=True):
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/python/ops/clip_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.util.tf_export import tf_export


Expand All @@ -54,7 +55,7 @@ def clip_by_value(t, clip_value_min, clip_value_max,
A clipped `Tensor`.

Raises:
ValueError: if the clip tensors would trigger array broadcasting
ValueError: If the clip tensors would trigger array broadcasting
that would make the returned tensor larger than the input.
"""
with ops.name_scope(name, "clip_by_value",
Expand Down Expand Up @@ -243,13 +244,16 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

Raises:
TypeError: If `t_list` is not a sequence.
InvalidArgumentError: If global norm is not finite.
"""
if (not isinstance(t_list, collections.Sequence)
or isinstance(t_list, six.string_types)):
raise TypeError("t_list should be a sequence")
t_list = list(t_list)
if use_norm is None:
use_norm = global_norm(t_list, name)
use_norm = numerics.verify_tensor_all_finite(use_norm,
"Found Inf or NaN global norm.")

with ops.name_scope(name, "clip_by_global_norm",
t_list + [clip_norm]) as name:
Expand Down