diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 14266d61d56..b4d1fce0d94 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -65,6 +65,8 @@ tfjs_cc_library( ":Mul", ":Prelu", ":FusedBatchNorm", + ":Max", + ":Min", ":Sigmoid", ":Sub", ] @@ -79,6 +81,24 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Max", + srcs = ["kernels/Max.cc"], + deps = [ + ":backend", + ":util", + ], +) + +tfjs_cc_library( + name = "Min", + srcs = ["kernels/Min.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "Abs", srcs = ["kernels/Abs.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Max.cc b/tfjs-backend-wasm/src/cc/kernels/Max.cc new file mode 100644 index 00000000000..2df12cd20b4 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Max.cc @@ -0,0 +1,62 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/backend.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Max(int x_id, int reduce_size, int out_id) { + const auto x_info = backend::get_tensor_info(x_id); + const auto out_info = backend::get_tensor_info(out_id); + + float* x_buf = x_info.buf.f32; + int x_size = x_info.size; + + float* out_buf = out_info.buf.f32; + int out_size = out_info.size; + + float* x_offset = x_info.buf.f32; + + for (int i = 0; i < out_size; ++i) { + int offset = i * reduce_size; + float max = x_buf[offset]; + + float* x_iter_end = x_offset + reduce_size; + + for (float* x = x_offset; x < x_iter_end; ++x) { + float value = *x; + if (value > max) { + max = value; + } + } + + x_offset += reduce_size; + + out_buf[i] = max; + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Min.cc b/tfjs-backend-wasm/src/cc/kernels/Min.cc new file mode 100644 index 00000000000..e9217c117aa --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Min.cc @@ -0,0 +1,62 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/backend.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Min(int x_id, int reduce_size, int out_id) { + const auto x_info = backend::get_tensor_info(x_id); + const auto out_info = backend::get_tensor_info(out_id); + + float* x_buf = x_info.buf.f32; + int x_size = x_info.size; + + float* out_buf = out_info.buf.f32; + int out_size = out_info.size; + + float* x_offset = x_info.buf.f32; + + for (int i = 0; i < out_size; ++i) { + int offset = i * reduce_size; + float min = x_buf[offset]; + + float* x_iter_end = x_offset + reduce_size; + + for (float* x = x_offset; x < x_iter_end; ++x) { + float value = *x; + if (value < min) { + min = value; + } + } + + x_offset += reduce_size; + + out_buf[i] = min; + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Max.ts b/tfjs-backend-wasm/src/kernels/Max.ts new file mode 100644 index 00000000000..ab688b76c80 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Max.ts @@ -0,0 +1,65 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface MaxInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface MaxAttrs extends NamedAttrMap { + axes: number[]; +} + +let wasmMax: (xId: number, reduceSize: number, outId: number) => void; + +function setup(backend: BackendWasm): void { + wasmMax = + backend.wasm.cwrap('Max', null /*void*/, ['number, number, number']); +} + +function max(args: {backend: BackendWasm, inputs: MaxInputs, attrs: MaxAttrs}): + TensorInfo { + const {backend, inputs, attrs} = args; + const {axes} = attrs; + const {x} = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + + backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + const [outShape, reduceShape] = + backend_util.computeOutAndReduceShapes(x.shape, axes); + const reduceSize = util.sizeFromShape(reduceShape); + + const out = backend.makeOutput(outShape, x.dtype); + if (util.sizeFromShape(x.shape) === 0) { + return out; + } + + const outId = backend.dataIdMap.get(out.dataId).id; + + wasmMax(xId, reduceSize, outId); + return out; +} + +registerKernel({ + kernelName: 'Max', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: max +}); diff --git a/tfjs-backend-wasm/src/kernels/Min.ts b/tfjs-backend-wasm/src/kernels/Min.ts new file mode 100644 index 00000000000..c2c8c61d202 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Min.ts @@ -0,0 +1,65 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface MinInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface MinAttrs extends NamedAttrMap { + axes: number[]; +} + +let wasmMin: (xId: number, reduceSize: number, outId: number) => void; + +function setup(backend: BackendWasm): void { + wasmMin = + backend.wasm.cwrap('Min', null /*void*/, ['number, number, number']); +} + +function min(args: {backend: BackendWasm, inputs: MinInputs, attrs: MinAttrs}): + TensorInfo { + const {backend, inputs, attrs} = args; + const {axes} = attrs; + const {x} = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + + backend_util.assertAxesAreInnerMostDims('min', axes, x.shape.length); + const [outShape, reduceShape] = + backend_util.computeOutAndReduceShapes(x.shape, axes); + const reduceSize = util.sizeFromShape(reduceShape); + + const out = backend.makeOutput(outShape, x.dtype); + if (util.sizeFromShape(x.shape) === 0) { + return out; + } + + const outId = backend.dataIdMap.get(out.dataId).id; + + wasmMin(xId, reduceSize, outId); + return out; +} + +registerKernel({ + kernelName: 'Min', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: min +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 193d3ccfefe..639277977db 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -25,6 +25,8 @@ import './FusedBatchNorm'; import './Cast'; import './Div'; import './Mul'; +import './Min'; +import './Max'; import './Prelu'; import './Reshape'; import './Sigmoid'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 2acd1e46de8..e6d7d2ec301 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -27,7 +27,7 @@ const grepFilter = env.specFilter; /** Tests that have these substrings in their name will be included. */ const INCLUDE_LIST: string[] = [ 'add ', 'matmul ', 'prelu ', ' cast', 'sigmoid', 'abs ', 'sub ', 'mul ', - 'div ', 'batchNorm', 'slice ', 'square ' + 'div ', 'batchNorm', 'slice ', 'square ', ' min ', ' max ' ]; /** Tests that have these substrings in their name will be excluded. */ const EXCLUDE_LIST: string[] = [ @@ -58,7 +58,21 @@ const EXCLUDE_LIST: string[] = [ // Mul 'broadcast 5D + 2D', // Broadcasting along inner dims not supported yet. - 'broadcast 6D + 2D' // Broadcasting along inner dims not supported yet. + 'broadcast 6D + 2D', // Broadcasting along inner dims not supported yet. + + // max + 'max x=[', // Pool not yet implemented. + 'max index corresponds to start of a non-initial window', // argMax not yet + // implemented. + + // min + 'min index corresponds to start of a non-initial window', // argMin not yet + // implemented. + + // min and max + 'derivative: 1D tensor with max or min value', // Clip not yet implemented. + '2D, axis=0', // Permuted axes requires transpose, which is not yet + // implemented. ]; /** diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index 38766f7646b..0d3edb1ebb8 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -270,7 +270,7 @@ function gradForMinAndMax( dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)) as T; } return { - $x: () => { + x: () => { const dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); return permutedAxes == null ? dx : dx.transpose(permutedAxes); } @@ -320,11 +320,14 @@ function min_( const grad = (dy: T, saved: Tensor[]) => gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); + + const inputsToSave = [$x]; + const outputsToSave: boolean[] = [true]; let res = ENGINE.runKernelFunc((backend, save) => { const y = backend.min($x, axes); save([xOrig, y]); return y as T; - }, {$x}, grad); + }, {x: $x}, grad, 'Min', {axes}, inputsToSave, outputsToSave); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); res = res.reshape(newShape) as T; @@ -375,11 +378,14 @@ function max_( const grad = (dy: T, saved: Tensor[]) => gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); + + const inputsToSave = [$x]; + const outputsToSave: boolean[] = [true]; let res = ENGINE.runKernelFunc((backend, save) => { const y = backend.max($x, axes); save([xOrig, y]); return y; - }, {$x}, grad); + }, {x: $x}, grad, 'Max', {axes}, inputsToSave, outputsToSave); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); res = res.reshape(newShape) as T;