diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 4a078830e5f..94a90410256 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -79,6 +79,7 @@ tfjs_cc_library( ":Abs", ":Add", ":BatchMatMul", + ":ClipByValue", ":CropAndResize", ":Conv2D", ":DepthwiseConv2dNative", @@ -151,6 +152,16 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "ClipByValue", + hdrs = ["kernels/ClipByValue.h"], + srcs = ["kernels/ClipByValue.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "CropAndResize", srcs = ["kernels/CropAndResize.cc"], @@ -315,3 +326,11 @@ tfjs_unit_test( ":FusedConv2D", ] ) + +tfjs_unit_test( + name = "ClipByValue_test", + srcs = ["kernels/ClipByValue_test.cc"], + deps = [ + ":ClipByValue", + ] +) diff --git a/tfjs-backend-wasm/src/cc/kernels/ClipByValue.cc b/tfjs-backend-wasm/src/cc/kernels/ClipByValue.cc new file mode 100644 index 00000000000..5dbad0f765c --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/ClipByValue.cc @@ -0,0 +1,94 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/kernels/ClipByValue.h" + +#include +#include +#include +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace { +// These float values are keys to creating the clip operator. We use +// std::array instead of a vanilla array as it implements the compare operator +// needed for std::map. +typedef std::array OperatorCacheKey; + +// The operator cache maps the cache key to the xnn_operator_t instantiated for +// this set of arguments to the xnn_operator. +std::map operator_cache; +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void ClipByValue(const int x_id, const float min, const float max, + const int out_id) { + auto& x_info = backend::get_tensor_info(x_id); + auto& out_info = backend::get_tensor_info_out(out_id); + + const float* x_buf = x_info.f32(); + float* out_buf = out_info.f32_write(); + + xnn_operator_t clamp_op = nullptr; + OperatorCacheKey cache_key = {min, max}; + auto operator_cache_idx = operator_cache.find(cache_key); + if (operator_cache_idx == operator_cache.end()) { + const int channels = 1; + const int strides = channels; + const int flags = 0; + xnn_status status = xnn_create_clamp_nc_f32(channels, strides, strides, min, + max, flags, &clamp_op); + if (status != xnn_status_success) { + util::warn( + "XNN status for xnn_create_clamp_nc_f32 is not successful. Got " + "status %d. Use -c dbg to see XNN logs.", + status); + } + operator_cache.emplace(cache_key, clamp_op); + + tfjs::backend::xnn_operator_count++; + } else { + clamp_op = operator_cache_idx->second; + } + + const int batch_size = x_info.size; + xnn_status status = xnn_setup_clamp_nc_f32( + clamp_op, batch_size, x_buf, out_buf, nullptr /* thread pool */); + if (status != xnn_status_success) { + util::warn( + "XNN status for xnn_setup_clamp_nc_f32 is not successful. Got " + "status %d. Use -c dbg to see XNN logs.", + status); + } + + xnn_run_operator(clamp_op, nullptr /* thread pool */); +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/ClipByValue.h b/tfjs-backend-wasm/src/cc/kernels/ClipByValue.h new file mode 100644 index 00000000000..ceaf00cb582 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/ClipByValue.h @@ -0,0 +1,28 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifndef KERNELS_CLIPBYVALUE_H_ +#define KERNELS_CLIPBYVALUE_H_ + +namespace tfjs { +namespace wasm { +extern "C" { + +void ClipByValue(const int x_id, const float min, const float max, + const int out_id); +} // extern "C" +} // namespace wasm +} // namespace tfjs + +#endif // KERNELS_CLIPBYVALUE_H_ diff --git a/tfjs-backend-wasm/src/cc/kernels/ClipByValue_test.cc b/tfjs-backend-wasm/src/cc/kernels/ClipByValue_test.cc new file mode 100644 index 00000000000..9c50eeaa9a6 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/ClipByValue_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#include + +#include "src/cc/backend.h" +#include "src/cc/kernels/ClipByValue.h" + +TEST(ClipByValue, xnn_operator_count) { + tfjs::wasm::init(); + + ASSERT_EQ(0, tfjs::backend::num_tensors()); + + int x0_id = 0; + int x1_id = 1; + int size = 2; + int min = 0; + int max = 1; + float x_values[2] = {1, 2}; + int out_id = 3; + float out_values[2] = {0, 0}; + + tfjs::wasm::register_tensor(x0_id, size, x_values); + tfjs::wasm::register_tensor(x1_id, size, x_values); + tfjs::wasm::register_tensor(out_id, size, out_values); + + ASSERT_EQ(3, tfjs::backend::num_tensors()); + ASSERT_EQ(0, tfjs::backend::xnn_operator_count); + + // One new xnn_operator should be created for the first call to clip. + tfjs::wasm::ClipByValue(x0_id, min, max, out_id); + ASSERT_EQ(1, tfjs::backend::xnn_operator_count); + + // No new xnn_operators should be created for the second call to clip with + // the same min/max. + tfjs::wasm::ClipByValue(x1_id, min, max, out_id); + ASSERT_EQ(1, tfjs::backend::xnn_operator_count); + + // One new xnn_operator should be created for another call to clip with new + // min/max. + tfjs::wasm::ClipByValue(x0_id, min, max + 1, out_id); + ASSERT_EQ(2, tfjs::backend::xnn_operator_count); + + // No new xnn_operators should be created for the next call to prelu with + // the same min/max. + tfjs::wasm::ClipByValue(x1_id, min, max + 1, out_id); + ASSERT_EQ(2, tfjs::backend::xnn_operator_count); + + // Disposing x's should not remove xnn operators. + tfjs::wasm::dispose_data(x0_id); + tfjs::wasm::dispose_data(x1_id); + ASSERT_EQ(2, tfjs::backend::xnn_operator_count); + + tfjs::wasm::dispose(); +} diff --git a/tfjs-backend-wasm/src/kernels/ClipByValue.ts b/tfjs-backend-wasm/src/kernels/ClipByValue.ts new file mode 100644 index 00000000000..2e3f529addb --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/ClipByValue.ts @@ -0,0 +1,80 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface ClipByValueInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface ClipByValueAttrs extends NamedAttrMap { + min: number; + max: number; +} + +let wasmClip: (xId: number, min: number, max: number, outId: number) => void; + +function setup(backend: BackendWasm) { + wasmClip = backend.wasm.cwrap('ClipByValue', null /* void */, [ + 'number', // x_id + 'number', // min + 'number', // max + 'number' // out_id + ]); +} + +function clip(args: { + inputs: ClipByValueInputs, + backend: BackendWasm, + attrs: ClipByValueAttrs +}) { + const {inputs, backend, attrs} = args; + const {x} = inputs; + const {min, max} = attrs; + const xId = backend.dataIdMap.get(x.dataId).id; + const out = backend.makeOutput(x.shape, 'float32'); + const outId = backend.dataIdMap.get(out.dataId).id; + wasmClip(xId, min, max, outId); + return out; +} + +registerKernel({ + kernelName: 'ClipByValue', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: clip +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 085baf4af3e..0dd4b6ffe08 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -21,14 +21,19 @@ import './Abs'; import './Add'; import './BatchMatMul'; -import './CropAndResize'; -import './FusedBatchNorm'; import './Cast'; +import './ClipByValue'; import './Concat'; +import './Conv2D'; +import './CropAndResize'; +import './DepthwiseConv2dNative'; import './Div'; -import './Mul'; -import './Min'; +import './FusedBatchNorm'; +import './FusedConv2D'; import './Max'; +import './Min'; +import './Mul'; +import './PadV2'; import './Prelu'; import './Reshape'; import './Sigmoid'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index e00e6b95f58..bb7f7db7183 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -172,6 +172,7 @@ const TEST_FILTERS: TestFilter[] = [ excludes: ['oneHot'] // oneHot not yet implemented. }, {include: 'pad ', excludes: ['complex', 'zerosLike']}, + {include: 'clip', excludes: ['gradient']}, ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts index 69527bbb786..55977225d40 100644 --- a/tfjs-core/src/ops/unary_ops.ts +++ b/tfjs-core/src/ops/unary_ops.ts @@ -427,17 +427,19 @@ function clipByValue_( const grad = (dy: T, saved: Tensor[]) => { const [$x] = saved; return { - $x: () => dy.where( - $x.greaterEqual(clipValueMin) - .logicalAnd($x.lessEqual(clipValueMax)), - zerosLike(dy)) as T, + x: () => dy.where( + $x.greaterEqual(clipValueMin) + .logicalAnd($x.lessEqual(clipValueMax)), + zerosLike(dy)) as T, }; }; + const inputsToSave = [$x]; + const attr = {min: clipValueMin, max: clipValueMax}; return ENGINE.runKernelFunc((backend, save) => { const res = backend.clip($x, clipValueMin, clipValueMax); save([$x]); return res; - }, {$x}, grad); + }, {x: $x}, grad, 'ClipByValue', attr, inputsToSave); } /** diff --git a/tfjs-core/src/ops/unary_ops_test.ts b/tfjs-core/src/ops/unary_ops_test.ts index fa6f6fb1056..488bfcef657 100644 --- a/tfjs-core/src/ops/unary_ops_test.ts +++ b/tfjs-core/src/ops/unary_ops_test.ts @@ -3124,7 +3124,7 @@ describeWithFlags('clip', ALL_ENVS, () => { expect(f).toThrowError(); }); - it('derivative: 1D tensor', async () => { + it('gradient: 1D tensor', async () => { const min = -1; const max = 2; const x = tf.tensor1d([3, -2, 1]); // Only 1 is not clipped. @@ -3136,7 +3136,7 @@ describeWithFlags('clip', ALL_ENVS, () => { expectArraysClose(await gradients.data(), [0, 0, 500]); }); - it('derivative: 1D tensor with max or min value', async () => { + it('gradient: 1D tensor with max or min value', async () => { const min = -1; const max = 2; const x = tf.tensor1d([-1, 1, 2, 3]); @@ -3148,7 +3148,7 @@ describeWithFlags('clip', ALL_ENVS, () => { expectArraysClose(await gradients.data(), [1, 10, 100, 0]); }); - it('derivative: scalar', async () => { + it('gradient: scalar', async () => { const min = -1; const max = 2; const x = tf.scalar(-10); // Clipped. @@ -3173,7 +3173,7 @@ describeWithFlags('clip', ALL_ENVS, () => { expectArraysClose(await gradients.data(), [0]); }); - it('derivate with primitive as input', async () => { + it('gradient with primitive as input', async () => { const min = -1; const max = 2; const x = -10;