Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tf.linalg.normalize #29691

Merged
merged 14 commits into from Jun 24, 2019
19 changes: 19 additions & 0 deletions tensorflow/python/kernel_tests/BUILD
Expand Up @@ -3506,6 +3506,25 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)

cuda_py_test(
name = "normalize_op_test",
size = "medium",
srcs = ["normalize_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn",
],
shard_count = 20,
# TODO(b/117236102): Re-enable in msan build.
tags = [
"no_windows_gpu",
"nomsan",
],
xla_enable_strict_auto_jit = True,
)

cuda_py_test(
name = "tensordot_op_test",
size = "medium",
Expand Down
98 changes: 98 additions & 0 deletions tensorflow/python/kernel_tests/normalize_op_test.py
@@ -0,0 +1,98 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.norm."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_impl
from tensorflow.python.platform import test as test_lib


def _AddTest(test, test_name, fn):
test_name = "_".join(["test", test_name])
if hasattr(test, test_name):
raise RuntimeError("Test %s defined more than once" % test_name)
setattr(test, test_name, fn)


# pylint: disable=redefined-builtin
def _Normalize(x, ord, axis):
if isinstance(axis, (list, tuple)):
norm = np.linalg.norm(x, ord, tuple(axis))
if axis[0] < axis[1]:
# This prevents axis to be inserted in-between
# e.g. when (-2, -1)
for d in reversed(axis):
norm = np.expand_dims(norm, d)
else:
for d in axis:
norm = np.expand_dims(norm, d)
return x / norm
elif axis is None:
# Tensorflow handles None differently
norm = np.linalg.norm(x.flatten(), ord, axis)
return x / norm
else:
norm = np.apply_along_axis(np.linalg.norm, axis, x, ord)
return x / np.expand_dims(norm, axis)


class NormalizeOpTest(test_lib.TestCase):
pass


def _GetNormalizeOpTest(dtype_, shape_, ord_, axis_):

@test_util.run_in_graph_and_eager_modes
def Test(self):
is_matrix_norm = (isinstance(axis_, tuple) or
isinstance(axis_, list)) and len(axis_) == 2
is_fancy_p_norm = np.isreal(ord_) and np.floor(ord_) != ord_
if ((not is_matrix_norm and ord_ == "fro") or
(is_matrix_norm and is_fancy_p_norm)):
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
if ord_ == "euclidean" or (axis_ is None and len(shape) > 2):
self.skipTest("Not supported by numpy.linalg.norm")
matrix = np.random.randn(*shape_).astype(dtype_)
if dtype_ in (np.complex64, np.complex128):
matrix += 1j * np.random.randn(*shape_).astype(dtype_)
tf_np_n, _ = self.evaluate(nn_impl.normalize(matrix, ord_, axis_))
np_n = _Normalize(matrix, ord_, axis_)
self.assertAllClose(tf_np_n, np_n, rtol=1e-5, atol=1e-5)

return Test

# pylint: disable=redefined-builtin
if __name__ == "__main__":
for dtype in np.float32, np.float64, np.complex64, np.complex128:
for rows in 2, 5:
for cols in 2, 5:
for batch in [], [2], [2, 3]:
shape = batch + [rows, cols]
for ord in "euclidean", "fro", 0.5, 1, 2, np.inf:
for axis in [
None, (-2, -1), (-1, -2), -len(shape), 0, len(shape) - 1
]:
name = "%s_%s_ord_%s_axis_%s" % (
dtype.__name__, "_".join(map(str, shape)), ord, axis)
_AddTest(NormalizeOpTest, "Normalize_" + name,
_GetNormalizeOpTest(dtype, shape, ord, axis))

test_lib.main()
60 changes: 60 additions & 0 deletions tensorflow/python/ops/nn_impl.py
Expand Up @@ -31,6 +31,7 @@
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops # pylint: disable=unused-import
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import gen_sparse_ops
Expand Down Expand Up @@ -435,6 +436,65 @@ def swish(features):
return features * math_ops.sigmoid(features)


# pylint: disable=redefined-builtin
@tf_export("linalg.normalize")
def normalize(tensor,
ord='euclidean',
axis=None,
name=None):
"""Normalizes `tensor` along dimension `axis` using specified norm.

This uses `tf.linalg.norm` to compute the norm along `axis`.

This function can compute several different vector norms (the 1-norm, the
Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).

Args:
tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`,
`1`, `2`, `np.inf` and any positive real number yielding the corresponding
p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
`tensor` is a matrix and equivalent to 2-norm for vectors.
Some restrictions apply:
a) The Frobenius norm `'fro'` is not defined for vectors,
b) If axis is a 2-tuple (matrix norm), only `'euclidean'`, '`fro'`, `1`,
`2`, `np.inf` are supported.
See the description of `axis` on how to compute norms for a batch of
vectors or matrices stored in a tensor.
axis: If `axis` is `None` (the default), the input is considered a vector
and a single vector norm is computed over the entire set of values in the
tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
`norm(reshape(tensor, [-1]), ord=ord)`.
If `axis` is a Python integer, the input is considered a batch of vectors,
and `axis` determines the axis in `tensor` over which to compute vector
norms.
If `axis` is a 2-tuple of Python integers it is considered a batch of
matrices and `axis` determines the axes in `tensor` over which to compute
a matrix norm.
Negative indices are supported. Example: If you are passing a tensor that
can be either a matrix or a batch of matrices at runtime, pass
`axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
computed.
name: The name of the op.

Returns:
normalized: A normalized `Tensor` with the same shape as `tensor`.
norm: The computed norms with the same shape and dtype `tensor` but the
final axis is 1 instead. Same as running
`tf.cast(tf.linalg.norm(tensor, ord, axis keepdims=True), tensor.dtype)`.

Raises:
ValueError: If `ord` or `axis` is invalid.
"""
with ops.name_scope(name, "normalize", [tensor]) as name:
tensor = ops.convert_to_tensor(tensor)
norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
norm = math_ops.cast(norm, tensor.dtype)
normalized = tensor / norm
return normalized, norm


@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
Expand Up @@ -168,6 +168,10 @@ tf_module {
name: "norm"
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "normalize"
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\'], "
}
member_method {
name: "qr"
argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
Expand Up @@ -168,6 +168,10 @@ tf_module {
name: "norm"
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "normalize"
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\'], "
}
member_method {
name: "qr"
argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
Expand Down