diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index c54cbcc4185..205fc5c1752 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -128,6 +128,8 @@ tfjs_cc_library( ":Conv2D", ":DepthwiseConv2dNative", ":FloorDiv", + ":Minimum", + ":Maximum", ":FusedConv2D", ":FusedDepthwiseConv2D", ":Div", @@ -289,6 +291,26 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Minimum", + srcs = ["kernels/Minimum.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + +tfjs_cc_library( + name = "Maximum", + srcs = ["kernels/Maximum.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + tfjs_cc_library( name = "FusedConv2D", srcs = ["kernels/FusedConv2D.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Maximum.cc b/tfjs-backend-wasm/src/cc/kernels/Maximum.cc new file mode 100644 index 00000000000..883531414bc --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Maximum.cc @@ -0,0 +1,53 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Maximum(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + binary_f32(a_id, b_id, out_id, + [](float a, float b) { return std::max(a, b); }); + break; + case DType::int32: + binary_i32(a_id, b_id, out_id, + [](int a, int b) { return std::max(a, b); }); + break; + default: + util::warn( + "Maximum for tensor ids %d and %d failed. Unsupported dtype %d", + a_id, b_id, dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Minimum.cc b/tfjs-backend-wasm/src/cc/kernels/Minimum.cc new file mode 100644 index 00000000000..4c418290488 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Minimum.cc @@ -0,0 +1,53 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Minimum(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + binary_f32(a_id, b_id, out_id, + [](float a, float b) { return std::min(a, b); }); + break; + case DType::int32: + binary_i32(a_id, b_id, out_id, + [](int a, int b) { return std::min(a, b); }); + break; + default: + util::warn( + "Minimum for tensor ids %d and %d failed. Unsupported dtype %d", + a_id, b_id, dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Maximum.ts b/tfjs-backend-wasm/src/kernels/Maximum.ts new file mode 100644 index 00000000000..ca2216cc827 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Maximum.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import { registerBinaryKernel } from './binary_kernel'; +const supportsBroadcast = false; +registerBinaryKernel('Maximum', supportsBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Minimum.ts b/tfjs-backend-wasm/src/kernels/Minimum.ts new file mode 100644 index 00000000000..dede1c4fafe --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Minimum.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import { registerBinaryKernel } from './binary_kernel'; +const supportsBroadcast = false; +registerBinaryKernel('Minimum', supportsBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 3360c73762b..9bf8e7a36ae 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -37,8 +37,10 @@ import './FusedBatchNorm'; import './FusedConv2D'; import './FusedDepthwiseConv2D'; import './Max'; +import './Maximum'; import './MaxPool'; import './Min'; +import './Minimum'; import './Mul'; import './NonMaxSuppressionV3'; import './PadV2'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 514deefd5e2..2d922483d16 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -195,6 +195,22 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'argmax', excludes: ['gradient']}, {include: 'exp '}, {include: 'unstack'}, + { + include: 'minimum', + excludes: [ + 'gradient', // Not yet implemented. + 'broadcasts 2x1 Tensor2D and 2x2 Tensor2D' // Broadcasting along inner + // dims not supported yet. + ] + }, + { + include: 'maximum', + excludes: [ + 'gradient', // Not yet implemented. + 'broadcasts 2x1 Tensor2D and 2x2 Tensor2D' // Broadcasting along inner + // dims not supported yet. + ] + }, ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/ops/binary_ops.ts b/tfjs-core/src/ops/binary_ops.ts index 67f48fe4d4d..e6d7e241429 100644 --- a/tfjs-core/src/ops/binary_ops.ts +++ b/tfjs-core/src/ops/binary_ops.ts @@ -673,13 +673,13 @@ function minimum_( const [$a, $b] = saved; const derA = () => dy.mul($a.lessEqual($b).toFloat()); const derB = () => dy.mul($a.greater($b).toFloat()); - return {$a: derA, $b: derB}; + return {a: derA, b: derB}; }; return ENGINE.runKernelFunc((backend, save) => { const res = backend.minimum($a, $b); save([$a, $b]); return res; - }, {$a, $b}, der) as T; + }, {a: $a, b: $b}, der, 'Minimum') as T; } /** @@ -738,13 +738,13 @@ function maximum_( const [$a, $b] = saved; const derA = () => dy.mul($a.greaterEqual($b).toFloat()); const derB = () => dy.mul($a.less($b).toFloat()); - return {$a: derA, $b: derB}; + return {a: derA, b: derB}; }; return ENGINE.runKernelFunc((backend, save) => { const res = backend.maximum($a, $b); save([$a, $b]); return res; - }, {$a, $b}, der) as T; + }, {a: $a, b: $b}, der, 'Maximum') as T; } /**