From 7e6129061b991f8642855a1bc2f91bdd77031427 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Mon, 15 Jun 2020 01:33:54 -0700 Subject: [PATCH] [WASM] Add equal op Also enabled some notEqual tests. --- tfjs-backend-wasm/src/cc/BUILD | 10 ++++ tfjs-backend-wasm/src/cc/kernels/Equal.cc | 61 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/Equal.ts | 20 +++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-backend-wasm/src/setup_test.ts | 18 ++++++ tfjs-core/src/ops/equal_test.ts | 2 +- tfjs-core/src/ops/not_equal_test.ts | 2 +- 7 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/Equal.cc create mode 100644 tfjs-backend-wasm/src/kernels/Equal.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 1e3ee161e6f..71a1bf001d2 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -194,6 +194,7 @@ tfjs_cc_library( ":CropAndResize", ":DepthwiseConv2dNative", ":Div", + ":Equal", ":Exp", ":FloorDiv", ":FusedBatchNorm", @@ -398,6 +399,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Equal", + srcs = ["kernels/Equal.cc"], + deps = [ + ":binary", + ":util", + ], +) + tfjs_cc_library( name = "Exp", srcs = ["kernels/Exp.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Equal.cc b/tfjs-backend-wasm/src/cc/kernels/Equal.cc new file mode 100644 index 00000000000..6f16b3514e0 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Equal.cc @@ -0,0 +1,61 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +template +inline bool equal(T a, T b) { + return a == b; +} +} // namespace + +namespace tfjs { +namespace wasm { +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Equal(const size_t a_id, const size_t* a_shape_ptr, + const size_t a_shape_len, const size_t b_id, + const size_t* b_shape_ptr, const size_t b_shape_len, + const DType input_type, const size_t out_id) { + switch (input_type) { + case DType::float32: + compare_f32(a_id, b_id, out_id, equal); + break; + case DType::int32: + compare_i32(a_id, b_id, out_id, equal); + break; + case DType::boolean: + compare_bool(a_id, b_id, out_id, equal); + break; + default: + util::warn( + "Equal for tensor ids %d and %d failed. Unsupported input_type %d", + a_id, b_id, input_type); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Equal.ts b/tfjs-backend-wasm/src/kernels/Equal.ts new file mode 100644 index 00000000000..f216805c1d4 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Equal.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('Equal', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 9609def25f9..e03b9a2f3e5 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -33,6 +33,7 @@ import './Cos'; import './CropAndResize'; import './DepthwiseConv2dNative'; import './Div'; +import './Equal'; import './Exp'; import './Fill'; import './FloorDiv'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index ca8a36353bf..8336a0eac83 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -249,6 +249,15 @@ const TEST_FILTERS: TestFilter[] = [ { include: 'log ', }, + { + startsWith: 'equal ', + excludes: [ + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not + // supported yet. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. + ] + }, { include: 'greater ', excludes: [ @@ -289,6 +298,15 @@ const TEST_FILTERS: TestFilter[] = [ 'broadcasting Tensor4D shapes' // Same as above. ] }, + { + include: 'notEqual', + excludes: [ + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not + // supported yet. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. + ] + }, { include: 'mean ', excludes: [ diff --git a/tfjs-core/src/ops/equal_test.ts b/tfjs-core/src/ops/equal_test.ts index 7ffac5d2396..fab5a4c4aad 100644 --- a/tfjs-core/src/ops/equal_test.ts +++ b/tfjs-core/src/ops/equal_test.ts @@ -143,7 +143,7 @@ describeWithFlags('equal', ALL_ENVS, () => { const b = tf.tensor2d([[0.1, NaN], [1.1, NaN]], [2, 2], 'float32'); expectArraysClose(await tf.equal(a, b).data(), [0, 0, 1, 0]); }); - it('2D and 2D broadcast each with 1 dim', async () => { + it('broadcasting Tensor2D shapes each with 1 dim', async () => { const a = tf.tensor2d([1, 2, 5], [1, 3]); const b = tf.tensor2d([5, 1], [2, 1]); const res = tf.equal(a, b); diff --git a/tfjs-core/src/ops/not_equal_test.ts b/tfjs-core/src/ops/not_equal_test.ts index 94504519c8b..894438a3cc1 100644 --- a/tfjs-core/src/ops/not_equal_test.ts +++ b/tfjs-core/src/ops/not_equal_test.ts @@ -158,7 +158,7 @@ describeWithFlags('notEqual', ALL_ENVS, () => { expect(res.shape).toEqual([2, 3]); expectArraysEqual(await res.data(), [1, 0, 1, 0, 1, 1]); }); - it('2D and 2D broadcast each with 1 dim', async () => { + it('broadcasting Tensor2D shapes each with 1 dim', async () => { const a = tf.tensor2d([1, 2, 5], [1, 3]); const b = tf.tensor2d([5, 1], [2, 1]); const res = tf.notEqual(a, b);