Skip to content

Commit

Permalink
Add bfloat16 support for comparison CPU ops.
Browse files — browse the repository at this point in the history
PiperOrigin-RevId: 201557049
  • Loading branch information
tensorflower-gardener committed Jun 21, 2018
1 parent 86fb0cd commit 8fd7142
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 13 deletions.
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/cwise_op_equal_to_1.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
// Pre-commit line (the 2 deletions counted in this hunk's header): CPU "Equal"
// with functor::equal_to for six element types.
REGISTER6(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
uint8, int8, int16);
// Post-commit replacement (the 2 additions): the same argument list with
// bfloat16 appended. REGISTERn is not defined in this view; presumably it comes
// from cwise_ops_common.h -- confirm there how each listed type is registered.
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
uint8, int8, int16, bfloat16);
// Context shown alongside the change: builds the "ApproximateEqual" kernel for
// DEVICE_CPU with type constraint T = float.
REGISTER_KERNEL_BUILDER(
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
ApproximateEqualOp<CPUDevice, float>);
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/cwise_op_greater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
// Pre-commit line (the 2 deletions counted in this hunk's header): CPU
// "Greater" with functor::greater for eight element types.
REGISTER8(BinaryOp, CPU, "Greater", functor::greater, float, Eigen::half,
double, int32, int64, uint8, int8, int16);
// Post-commit replacement (the 2 additions): the same argument list with
// bfloat16 appended. REGISTERn is not defined in this view; presumably it comes
// from cwise_ops_common.h -- confirm there.
REGISTER9(BinaryOp, CPU, "Greater", functor::greater, float, Eigen::half,
double, int32, int64, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "Greater", functor::greater, float, Eigen::half,
double, int64, uint8, int8, int16);
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/cwise_op_greater_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
// Pre-commit line (the 2 deletions counted in this hunk's header): CPU
// "GreaterEqual" with functor::greater_equal for eight element types.
REGISTER8(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float,
Eigen::half, double, int32, int64, uint8, int8, int16);
// Post-commit replacement (the 2 additions): the same argument list with
// bfloat16 appended. REGISTERn is not defined in this view; presumably it comes
// from cwise_ops_common.h -- confirm there.
REGISTER9(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float,
Eigen::half, double, int32, int64, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float,
Eigen::half, double, int64, uint8, int8, int16);
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/core/kernels/cwise_op_less.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
// Pre-commit line (the 2 deletions counted in this hunk's header). Note that
// bfloat16 was already part of this nine-type list.
REGISTER9(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
bfloat16, int32, int64, uint8, int8, int16);
// Post-commit registrations: the same nine types split across two macro
// invocations. Every type must appear in exactly one group. The change as
// committed listed bfloat16 in both groups, which (assuming REGISTERn emits one
// kernel registration per listed type -- confirm in cwise_ops_common.h) would
// register two CPU "Less" kernels for bfloat16. bfloat16 therefore stays only
// in the first group and the second group shrinks to REGISTER4.
REGISTER5(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
bfloat16, int32);
REGISTER4(BinaryOp, CPU, "Less", functor::less, int64, uint8, int8, int16);

#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "Less", functor::less, float, Eigen::half, double,
int64, uint8, int8, int16);
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/core/kernels/cwise_op_less_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
// Pre-commit line (the 2 deletions counted in this hunk's header). Note that
// bfloat16 was already part of this nine-type list.
REGISTER9(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
bfloat16, double, int32, int64, uint8, int8, int16);
// Post-commit registrations: the same nine types split across two macro
// invocations. Every type must appear in exactly one group. The change as
// committed listed bfloat16 in both groups, which (assuming REGISTERn emits one
// kernel registration per listed type -- confirm in cwise_ops_common.h) would
// register two CPU "LessEqual" kernels for bfloat16. bfloat16 therefore stays
// only in the first group and the second group shrinks to REGISTER4.
REGISTER5(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
bfloat16, double, int32);
REGISTER4(BinaryOp, CPU, "LessEqual", functor::less_equal, int64, uint8, int8,
int16);

#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "LessEqual", functor::less_equal, float, Eigen::half,
double, int64, uint8, int8, int16);
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
// Pre-commit line (the 2 deletions counted in this hunk's header): CPU
// "NotEqual" with functor::not_equal_to for six element types.
REGISTER6(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
double, uint8, int8, int16);
// Post-commit replacement (the 2 additions): the same argument list with
// bfloat16 appended. REGISTERn is not defined in this view; presumably it comes
// from cwise_ops_common.h -- confirm there.
REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
double, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
double, uint8);
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/python/kernel_tests/cwise_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
np_ans = np_func(x)
with self.test_session(use_gpu=False):
inx = ops.convert_to_tensor(x)
if x.dtype in (np.float32, np.float64):
if x.dtype in (np.float32, np.float64,
dtypes_lib.bfloat16.as_numpy_dtype):
y = 1.1 * tf_func(inx)
np_ans *= 1.1
else:
Expand All @@ -105,6 +106,8 @@ def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
self.assertShapeEqual(np_ans, y)
if x.dtype == np.float16:
self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
else:
self.assertAllClose(np_ans, tf_cpu)

Expand Down

0 comments on commit 8fd7142

Please sign in to comment.