From 189ee595efb940cf40807fbab5aea8224cc296ef Mon Sep 17 00:00:00 2001 From: sapphi-red Date: Tue, 10 Dec 2019 09:48:39 +0900 Subject: [PATCH 1/2] [wasm] Add `Minimum` --- tfjs-backend-wasm/src/cc/BUILD | 11 ++++ tfjs-backend-wasm/src/cc/kernels/Minimum.cc | 53 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/Minimum.ts | 20 ++++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-backend-wasm/src/setup_test.ts | 8 +++ tfjs-core/src/ops/binary_ops.ts | 4 +- 6 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/Minimum.cc create mode 100644 tfjs-backend-wasm/src/kernels/Minimum.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index c54cbcc4185..3993b5db5d8 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -128,6 +128,7 @@ tfjs_cc_library( ":Conv2D", ":DepthwiseConv2dNative", ":FloorDiv", + ":Minimum", ":FusedConv2D", ":FusedDepthwiseConv2D", ":Div", @@ -289,6 +290,16 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Minimum", + srcs = ["kernels/Minimum.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + tfjs_cc_library( name = "FusedConv2D", srcs = ["kernels/FusedConv2D.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Minimum.cc b/tfjs-backend-wasm/src/cc/kernels/Minimum.cc new file mode 100644 index 00000000000..4c418290488 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Minimum.cc @@ -0,0 +1,53 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Minimum(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + binary_f32(a_id, b_id, out_id, + [](float a, float b) { return std::min(a, b); }); + break; + case DType::int32: + binary_i32(a_id, b_id, out_id, + [](int a, int b) { return std::min(a, b); }); + break; + default: + util::warn( + "Minimum for tensor ids %d and %d failed. Unsupported dtype %d", + a_id, b_id, dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Minimum.ts b/tfjs-backend-wasm/src/kernels/Minimum.ts new file mode 100644 index 00000000000..dede1c4fafe --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Minimum.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import { registerBinaryKernel } from './binary_kernel'; +const supportsBroadcast = false; +registerBinaryKernel('Minimum', supportsBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 3360c73762b..722a83722e0 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -39,6 +39,7 @@ import './FusedDepthwiseConv2D'; import './Max'; import './MaxPool'; import './Min'; +import './Minimum'; import './Mul'; import './NonMaxSuppressionV3'; import './PadV2'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 8a4da9c7301..e7538da8ee5 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -195,6 +195,14 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'argmax', excludes: ['gradient']}, {include: 'exp '}, {include: 'unstack'}, + { + include: 'minimum', + excludes: [ + 'gradient', // Not yet implemented. + 'broadcasts 2x1 Tensor2D and 2x2 Tensor2D' // Broadcasting along inner + // dims not supported yet. + ] + }, ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/ops/binary_ops.ts b/tfjs-core/src/ops/binary_ops.ts index 67f48fe4d4d..9e4d28f8e48 100644 --- a/tfjs-core/src/ops/binary_ops.ts +++ b/tfjs-core/src/ops/binary_ops.ts @@ -673,13 +673,13 @@ function minimum_( const [$a, $b] = saved; const derA = () => dy.mul($a.lessEqual($b).toFloat()); const derB = () => dy.mul($a.greater($b).toFloat()); - return {$a: derA, $b: derB}; + return {a: derA, b: derB}; }; return ENGINE.runKernelFunc((backend, save) => { const res = backend.minimum($a, $b); save([$a, $b]); return res; - }, {$a, $b}, der) as T; + }, {a: $a, b: $b}, der, 'Minimum') as T; } /** From 625db8c14aba996ee8b9e70332f2866468b03d8b Mon Sep 17 00:00:00 2001 From: sapphi-red Date: Wed, 11 Dec 2019 10:27:51 +0900 Subject: [PATCH 2/2] [wasm] Add `Maximum` --- tfjs-backend-wasm/src/cc/BUILD | 11 ++++ tfjs-backend-wasm/src/cc/kernels/Maximum.cc | 53 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/Maximum.ts | 20 ++++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-backend-wasm/src/setup_test.ts | 8 +++ tfjs-core/src/ops/binary_ops.ts | 4 +- 6 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/Maximum.cc create mode 100644 tfjs-backend-wasm/src/kernels/Maximum.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 3993b5db5d8..205fc5c1752 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -129,6 +129,7 @@ tfjs_cc_library( ":DepthwiseConv2dNative", ":FloorDiv", ":Minimum", + ":Maximum", ":FusedConv2D", ":FusedDepthwiseConv2D", ":Div", @@ -300,6 +301,16 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Maximum", + srcs = ["kernels/Maximum.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + tfjs_cc_library( name = "FusedConv2D", srcs = ["kernels/FusedConv2D.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Maximum.cc b/tfjs-backend-wasm/src/cc/kernels/Maximum.cc new file mode 100644 index 00000000000..883531414bc --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Maximum.cc @@ -0,0 +1,53 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Maximum(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + binary_f32(a_id, b_id, out_id, + [](float a, float b) { return std::max(a, b); }); + break; + case DType::int32: + binary_i32(a_id, b_id, out_id, + [](int a, int b) { return std::max(a, b); }); + break; + default: + util::warn( + "Maximum for tensor ids %d and %d failed. Unsupported dtype %d", + a_id, b_id, dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Maximum.ts b/tfjs-backend-wasm/src/kernels/Maximum.ts new file mode 100644 index 00000000000..ca2216cc827 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Maximum.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import { registerBinaryKernel } from './binary_kernel'; +const supportsBroadcast = false; +registerBinaryKernel('Maximum', supportsBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 722a83722e0..9bf8e7a36ae 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -37,6 +37,7 @@ import './FusedBatchNorm'; import './FusedConv2D'; import './FusedDepthwiseConv2D'; import './Max'; +import './Maximum'; import './MaxPool'; import './Min'; import './Minimum'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index e7538da8ee5..3894458f250 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -203,6 +203,14 @@ const TEST_FILTERS: TestFilter[] = [ // dims not supported yet. ] }, + { + include: 'maximum', + excludes: [ + 'gradient', // Not yet implemented. + 'broadcasts 2x1 Tensor2D and 2x2 Tensor2D' // Broadcasting along inner + // dims not supported yet. + ] + }, ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/ops/binary_ops.ts b/tfjs-core/src/ops/binary_ops.ts index 9e4d28f8e48..e6d7e241429 100644 --- a/tfjs-core/src/ops/binary_ops.ts +++ b/tfjs-core/src/ops/binary_ops.ts @@ -738,13 +738,13 @@ function maximum_( const [$a, $b] = saved; const derA = () => dy.mul($a.greaterEqual($b).toFloat()); const derB = () => dy.mul($a.less($b).toFloat()); - return {$a: derA, $b: derB}; + return {a: derA, b: derB}; }; return ENGINE.runKernelFunc((backend, save) => { const res = backend.maximum($a, $b); save([$a, $b]); return res; - }, {$a, $b}, der) as T; + }, {a: $a, b: $b}, der, 'Maximum') as T; } /**