Skip to content

Commit

Permalink
Enable reciprocal GPU kernel on complex number
Browse files Browse the repository at this point in the history
Add complex reciprocal GPU test

Fix register number
  • Loading branch information
WindQAQ committed Jul 20, 2020
1 parent 320f227 commit f208126
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License.

namespace tensorflow {
namespace functor {
DEFINE_UNARY4(inverse, Eigen::half, float, double, int64);
DEFINE_UNARY6(inverse, Eigen::half, float, double, int64, complex64, complex128);
DEFINE_SIMPLE_BINARY3(inverse_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/core/kernels/cwise_op_reciprocal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Inv", functor::inverse, float, Eigen::half, double,
complex64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER4(UnaryOp, GPU, "Inv", functor::inverse, float, Eigen::half, double,
int64);
REGISTER6(UnaryOp, GPU, "Inv", functor::inverse, float, Eigen::half, double,
int64, complex64, complex128);
#endif

REGISTER5(SimpleBinaryOp, CPU, "InvGrad", functor::inverse_grad, float,
Expand All @@ -36,12 +36,12 @@ REGISTER3(SimpleBinaryOp, GPU, "InvGrad", functor::inverse_grad, float,
REGISTER6(UnaryOp, CPU, "Reciprocal", functor::inverse, float, Eigen::half,
double, complex64, complex128, bfloat16);
#else
REGISTER5(UnaryOp, CPU, "Reciprocal", functor::inverse, float, Eigen::half,
REGISTER6(UnaryOp, CPU, "Reciprocal", functor::inverse, float, Eigen::half,
double, complex64, complex128);
#endif // ENABLE_INTEL_MKL_BFLOAT16
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER4(UnaryOp, GPU, "Reciprocal", functor::inverse, float, Eigen::half,
double, int64);
double, int64, complex64, complex128);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER(UnaryOp, SYCL, "Reciprocal", functor::inverse, float);
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/python/kernel_tests/cwise_ops_unary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def testComplex64Basic(self):
self._compareBoth(x, np.abs, _ABS)
self._compareBoth(x, np.negative, math_ops.negative)
self._compareBoth(x, np.negative, _NEG)
self._compareCpu(y, self._inv, math_ops.reciprocal)
self._compareBoth(y, self._inv, math_ops.reciprocal)
self._compareCpu(x, np.square, math_ops.square)
self._compareCpu(y, np.sqrt, math_ops.sqrt)
self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
Expand Down Expand Up @@ -491,7 +491,7 @@ def testComplex128Basic(self):
self._compareBoth(x, np.abs, _ABS)
self._compareBoth(x, np.negative, math_ops.negative)
self._compareBoth(x, np.negative, _NEG)
self._compareCpu(y, self._inv, math_ops.reciprocal)
self._compareBoth(y, self._inv, math_ops.reciprocal)
self._compareCpu(x, np.square, math_ops.square)
self._compareCpu(y, np.sqrt, math_ops.sqrt)
self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
Expand Down

0 comments on commit f208126

Please sign in to comment.