diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD
index e4d0f8d8cd5..fa8c56fb933 100644
--- a/tfjs-backend-wasm/src/cc/BUILD
+++ b/tfjs-backend-wasm/src/cc/BUILD
@@ -676,12 +676,21 @@ tfjs_cc_library(
 tfjs_cc_library(
     name = "PadV2",
     srcs = ["kernels/PadV2.cc"],
+    hdrs = ["kernels/PadV2.h"],
     deps = [
         ":backend",
         ":util",
     ],
 )
 
+tfjs_unit_test(
+    name = "PadV2_test",
+    srcs = ["kernels/PadV2_test.cc"],
+    deps = [
+        ":PadV2",
+    ],
+)
+
 tfjs_cc_library(
     name = "Pow",
     srcs = ["kernels/Pow.cc"],
diff --git a/tfjs-backend-wasm/src/cc/kernels/PadV2.cc b/tfjs-backend-wasm/src/cc/kernels/PadV2.cc
index 7d1205e9603..c7fed57b7de 100644
--- a/tfjs-backend-wasm/src/cc/kernels/PadV2.cc
+++ b/tfjs-backend-wasm/src/cc/kernels/PadV2.cc
@@ -16,11 +16,28 @@
 #include <emscripten.h>
 #endif
 
+#include <xnnpack.h>
 #include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <map>
+#include <tuple>
+#include <vector>
 
 #include "src/cc/backend.h"
 #include "src/cc/util.h"
 
+namespace {
+// We use std::tuple as the cache key as it implements the compare operator
+// needed for std::map.
+typedef std::tuple<float, uint32_t> OperatorCacheKey;
+
+// The operator cache maps a padding configuration (pad value and flags) to
+// the xnn_operator_t instantiated for that configuration.
+std::map<OperatorCacheKey, xnn_operator_t> operator_cache;
+
+}  // namespace
+
 namespace {
 
 using tfjs::util::compute_strides;
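Reviewer note on the caching idiom above: `std::map` only requires that its key type define `operator<`, and `std::tuple` supplies lexicographic comparison for free, so a `(pad_value, flags)` pair can key the cache without a custom comparator. A minimal standalone sketch of the same pattern (names and the `int` value type are simplified stand-ins, not part of this PR):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <tuple>

// std::tuple provides lexicographic operator<, which is all std::map needs.
typedef std::tuple<float, uint32_t> CacheKey;

int main() {
  std::map<CacheKey, int> cache;  // int stands in for xnn_operator_t here.
  CacheKey key = {0.5f, 0};
  auto it = cache.find(key);
  if (it == cache.end()) {
    // First use of this configuration: create and memoize the "operator".
    cache.insert({key, 42});
  }
  std::cout << "cached entries: " << cache.size() << "\n";
  return 0;
}
```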
@@ -120,13 +137,14 @@ void pad_4d(const T* x_data, size_t x_shape[4], size_t paddings[8],
 
 // Generic pad implementation for n-dim tensors.
 template <typename T>
 void slow_pad_nd(const T* x_data, const std::vector<size_t>& x_shape,
-                 const std::vector<size_t>& paddings, const T pad_value,
+                 const std::vector<size_t>& pre_paddings,
+                 const std::vector<size_t>& post_paddings, const T pad_value,
                  T* out_data) {
   const size_t rank = x_shape.size();
   std::vector<size_t> out_shape(rank);
   for (size_t i = 0; i < rank; ++i) {
-    const size_t pad_left = paddings[i * 2];
-    const size_t pad_right = paddings[i * 2 + 1];
+    const size_t pad_left = pre_paddings[i];
+    const size_t pad_right = post_paddings[i];
     out_shape[i] = x_shape[i] + pad_left + pad_right;
   }
   const auto& in_strides = compute_strides(x_shape);
@@ -139,7 +157,7 @@ void slow_pad_nd(const T* x_data, const std::vector<size_t>& x_shape,
   for (size_t i = 0; i < in_size; ++i) {
     auto out_loc = offset_to_loc(i, in_strides);
     for (size_t j = 0; j < rank; ++j) {
-      out_loc[j] += paddings[j * 2];
+      out_loc[j] += pre_paddings[j];
     }
     const size_t out_offset = loc_to_offset(out_loc, out_strides);
     out_data[out_offset] = x_data[i];
@@ -148,7 +166,9 @@ void slow_pad_nd(const T* x_data, const std::vector<size_t>& x_shape,
 
 template <typename T>
 void pad(const T* x_data, const std::vector<size_t>& x_shape,
-         const std::vector<size_t>& paddings, const T pad_value, T* out_data) {
+         const std::vector<size_t>& pre_paddings,
+         const std::vector<size_t>& post_paddings, const T pad_value,
+         T* out_data) {
   const size_t rank = x_shape.size();
   if (rank <= 4) {
     // Expand the shape to be 4d.
@@ -165,8 +185,8 @@ void pad(const T* x_data, const std::vector<size_t>& x_shape,
 
   for (size_t i = 0; i < rank; ++i) {
     size_t j = i + rank_shift;
-    const size_t pad_left = paddings[i * 2];
-    const size_t pad_right = paddings[i * 2 + 1];
+    const size_t pad_left = pre_paddings[i];
+    const size_t pad_right = post_paddings[i];
     x_shape_4d[j] = x_shape[i];
     out_shape_4d[j] = x_shape[i] + pad_left + pad_right;
     paddings_4d[j * 2] = pad_left;
@@ -174,7 +194,8 @@
     }
     pad_4d(x_data, x_shape_4d, paddings_4d, pad_value, out_shape_4d, out_data);
   } else {
-    slow_pad_nd(x_data, x_shape, paddings, pad_value, out_data);
+    slow_pad_nd(x_data, x_shape, pre_paddings, post_paddings, pad_value,
+                out_data);
   }
 }
 
@@ -190,26 +211,64 @@
 EMSCRIPTEN_KEEPALIVE
 #endif
 void PadV2(const size_t x_id, const size_t* x_shape_ptr,
            const size_t x_shape_length, const DType dtype,
-           const size_t* paddings_ptr, const float pad_value,
-           const size_t out_id) {
+           const size_t* pre_paddings_ptr, const size_t* post_paddings_ptr,
+           const float pad_value, const size_t out_id) {
   auto x_shape =
       std::vector<size_t>(x_shape_ptr, x_shape_ptr + x_shape_length);
-  const size_t paddings_length = x_shape_length * 2;
-  auto paddings =
-      std::vector<size_t>(paddings_ptr, paddings_ptr + paddings_length);
+  auto pre_paddings =
+      std::vector<size_t>(pre_paddings_ptr, pre_paddings_ptr + x_shape_length);
+  auto post_paddings = std::vector<size_t>(post_paddings_ptr,
+                                           post_paddings_ptr + x_shape_length);
+
   auto& x_info = backend::get_tensor_info(x_id);
   auto& out_info = backend::get_tensor_info_out(out_id);
 
   switch (dtype) {
-    case DType::float32:
-      pad(x_info.f32(), x_shape, paddings, pad_value,
-          out_info.f32_write());
+    case DType::float32: {
+      xnn_operator_t pad_op = nullptr;
+      const uint32_t flags = 0;
+
+      OperatorCacheKey cache_key = {pad_value, flags};
+
+      auto operator_cache_idx = operator_cache.find(cache_key);
+      if (operator_cache_idx == operator_cache.end()) {
+        xnn_status status =
+            xnn_create_constant_pad_nd_x32(&pad_value, flags, &pad_op);
+        if (status != xnn_status_success) {
+          tfjs::util::warn(
+              "XNN status for xnn_create_constant_pad_nd_x32 is not "
+              "successful. Got status %d. Use -c dbg to see XNN logs.",
+              status);
+          return;
+        }
+
+        operator_cache.insert({cache_key, pad_op});
+
+        tfjs::backend::xnn_operator_count++;
+      } else {
+        pad_op = operator_cache_idx->second;
+      }
+
+      xnn_status status = xnn_setup_constant_pad_nd_x32(
+          pad_op, x_shape_length, x_shape_ptr, pre_paddings_ptr,
+          post_paddings_ptr, x_info.f32(), out_info.f32_write(),
+          nullptr /* threadpool */);
+      if (status != xnn_status_success) {
+        tfjs::util::warn(
+            "XNN status for xnn_setup_constant_pad_nd_x32 is not "
+            "successful. Got status %d. Use -c dbg to see XNN logs.",
+            status);
+        return;
+      }
+
+      xnn_run_operator(pad_op, nullptr /* threadpool */);
       break;
+    }
     case DType::int32:
-      pad(x_info.i32(), x_shape, paddings,
+      pad(x_info.i32(), x_shape, pre_paddings, post_paddings,
          static_cast<int32_t>(pad_value), out_info.i32_write());
       break;
     case DType::boolean:
-      pad(x_info.b(), x_shape, paddings, static_cast<bool>(pad_value),
-          out_info.b_write());
+      pad(x_info.b(), x_shape, pre_paddings, post_paddings,
+          static_cast<bool>(pad_value), out_info.b_write());
       break;
     default:
       util::warn("Pad for tensor id %d failed. Unknown dtype %d", x_id,
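Reviewer note on the float32 path above: it follows XNNPACK's create/setup/run operator lifecycle. Creation is comparatively expensive and depends only on `(pad_value, flags)`, which is why the kernel memoizes it; setup binds the per-call shapes and buffers, and run executes. A standalone sketch of the same sequence, assuming XNNPACK is linked at the API revision this PR builds against (error handling elided for brevity):

```cpp
#include <xnnpack.h>
#include <cstddef>
#include <vector>

int main() {
  xnn_initialize(/*allocator=*/nullptr);

  const float pad_value = 0.0f;
  xnn_operator_t pad_op = nullptr;
  // Created once per distinct (pad_value, flags) pair; this is the part the
  // kernel caches.
  xnn_create_constant_pad_nd_x32(&pad_value, /*flags=*/0, &pad_op);

  // Pad a 1x4 tensor with one element on each side of the last axis.
  std::vector<size_t> shape = {1, 4};
  std::vector<size_t> pre = {0, 1};
  std::vector<size_t> post = {0, 1};
  std::vector<float> input = {1, 2, 3, 4};
  std::vector<float> output(1 * 6);  // 4 + 1 + 1 = 6

  // Setup binds shapes and pointers; run executes. Both are cheap relative
  // to create, which is why reusing the cached operator pays off.
  xnn_setup_constant_pad_nd_x32(pad_op, shape.size(), shape.data(),
                                pre.data(), post.data(), input.data(),
                                output.data(), /*threadpool=*/nullptr);
  xnn_run_operator(pad_op, /*threadpool=*/nullptr);

  xnn_delete_operator(pad_op);
  return 0;
}
```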
diff --git a/tfjs-backend-wasm/src/cc/kernels/PadV2.h b/tfjs-backend-wasm/src/cc/kernels/PadV2.h
new file mode 100644
index 00000000000..69eee1070f1
--- /dev/null
+++ b/tfjs-backend-wasm/src/cc/kernels/PadV2.h
@@ -0,0 +1,35 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ===========================================================================*/
+
+#ifndef KERNELS_PADV2_H_
+#define KERNELS_PADV2_H_
+
+#include <cstddef>
+
+#include "src/cc/backend.h"
+
+namespace tfjs {
+namespace wasm {
+extern "C" {
+
+void PadV2(const size_t x_id, const size_t* x_shape_ptr,
+           const size_t x_shape_length, const DType dtype,
+           const size_t* pre_paddings_ptr, const size_t* post_paddings_ptr,
+           const float pad_value, const size_t out_id);
+}
+
+}  // namespace wasm
+}  // namespace tfjs
+
+#endif  // KERNELS_PADV2_H_
diff --git a/tfjs-backend-wasm/src/cc/kernels/PadV2_test.cc b/tfjs-backend-wasm/src/cc/kernels/PadV2_test.cc
new file mode 100644
index 00000000000..411f2ec70f2
--- /dev/null
+++ b/tfjs-backend-wasm/src/cc/kernels/PadV2_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ===========================================================================*/
+
+#include <gtest/gtest.h>
+
+#include <cstddef>
+#include <vector>
+
+#include "src/cc/backend.h"
+#include "src/cc/kernels/PadV2.h"
+
+TEST(PADV2, xnn_operator_lifetime) {
+  tfjs::wasm::init();
+
+  ASSERT_EQ(0, tfjs::backend::num_tensors());
+
+  const size_t x0_id = 1;
+  const size_t x1_id = 2;
+  const size_t x_size = 4;
+  float x_values[x_size] = {1, 2, 2, 2};
+
+  const size_t out_id = 3;
+  const size_t out_size = 8;
+  float out_values[out_size] = {0, 0, 0, 0, 0, 0, 0, 0};
+
+  tfjs::wasm::register_tensor(x0_id, x_size, x_values);
+  tfjs::wasm::register_tensor(x1_id, x_size, x_values);
+  tfjs::wasm::register_tensor(out_id, out_size, out_values);
+
+  ASSERT_EQ(3, tfjs::backend::num_tensors());
+  ASSERT_EQ(0, tfjs::backend::xnn_operator_count);
+
+  const size_t x_rank = 4;
+
+  std::vector<size_t> x_shape = {1, 1, 1, 4};
+  std::vector<size_t> pre_paddings = {0, 0, 0, 0};
+  std::vector<size_t> post_paddings = {0, 0, 0, 0};
+
+  const float pad_value = 0.0;
+
+  size_t* x_shape_ptr = x_shape.data();
+  size_t* pre_paddings_ptr = pre_paddings.data();
+  size_t* post_paddings_ptr = post_paddings.data();
+
+  const DType dtype = DType::float32;
+
+  // One new xnn_operator should be created for the first call to
+  // PadV2.
+  tfjs::wasm::PadV2(x0_id, x_shape_ptr, x_rank, dtype, pre_paddings_ptr,
+                    post_paddings_ptr, pad_value, out_id);
+  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);
+
+  // No new xnn_operators should be created for the second call to
+  // PadV2 with the same arguments.
+  tfjs::wasm::PadV2(x0_id, x_shape_ptr, x_rank, dtype, pre_paddings_ptr,
+                    post_paddings_ptr, pad_value, out_id);
+  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);
+
+  // No new xnn_operators should be created for a third call to
+  // PadV2 with a new x id but otherwise the same arguments.
+  tfjs::wasm::PadV2(x1_id, x_shape_ptr, x_rank, dtype, pre_paddings_ptr,
+                    post_paddings_ptr, pad_value, out_id);
+  ASSERT_EQ(1, tfjs::backend::xnn_operator_count);
+
+  // One new xnn_operator should be created for another call to PadV2
+  // with a different pad_value.
+  const float new_pad_value = 0.5;
+  tfjs::wasm::PadV2(x0_id, x_shape_ptr, x_rank, dtype, pre_paddings_ptr,
+                    post_paddings_ptr, new_pad_value, out_id);
+  ASSERT_EQ(2, tfjs::backend::xnn_operator_count);
+
+  tfjs::wasm::dispose();
+}
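Reviewer note: the test pins down operator-cache lifetime but only exercises zero paddings. For reference, the shape arithmetic both the kernel and the TypeScript wrapper rely on is `out[i] = pre[i] + x[i] + post[i]` per axis. A hypothetical helper, not part of this PR, making that explicit:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: computes the padded output shape from the split
// pre/post padding vectors, axis by axis.
std::vector<size_t> padded_shape(const std::vector<size_t>& x_shape,
                                 const std::vector<size_t>& pre_paddings,
                                 const std::vector<size_t>& post_paddings) {
  std::vector<size_t> out_shape(x_shape.size());
  for (size_t i = 0; i < x_shape.size(); ++i) {
    out_shape[i] = pre_paddings[i] + x_shape[i] + post_paddings[i];
  }
  return out_shape;
}

// e.g. padded_shape({1, 1, 1, 4}, {0, 0, 0, 1}, {0, 0, 0, 3}) == {1, 1, 1, 8}
```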
diff --git a/tfjs-backend-wasm/src/kernels/PadV2.ts b/tfjs-backend-wasm/src/kernels/PadV2.ts
index 47b60dd449b..10ee637c34f 100644
--- a/tfjs-backend-wasm/src/kernels/PadV2.ts
+++ b/tfjs-backend-wasm/src/kernels/PadV2.ts
@@ -15,38 +15,32 @@
  * =============================================================================
  */
 
-import {KernelConfig, NamedAttrMap, NamedTensorInfoMap, TensorInfo} from '@tensorflow/tfjs-core';
+import {KernelConfig, KernelFunc, PadV2, PadV2Attrs, PadV2Inputs} from '@tensorflow/tfjs-core';
 
 import {BackendWasm} from '../backend_wasm';
 
 import {CppDType} from './types';
 
-interface PadInputs extends NamedTensorInfoMap {
-  x: TensorInfo;
-}
-
-interface PadAttrs extends NamedAttrMap {
-  paddings: Array<[number, number]>;
-  constantValue: number;
-}
-
 let wasmPadV2: (
     xId: number, xShapeBytes: Uint8Array, xShapeLength: number,
     xDtype: number,
-    paddingsBytes: Uint8Array, constantValue: number, outId: number) => void;
+    prePaddingsBytes: Uint8Array, postPaddingsBytes: Uint8Array,
+    constantValue: number, outId: number) => void;
 
 function setup(backend: BackendWasm) {
-  wasmPadV2 = backend.wasm.cwrap('PadV2', null /* void */, [
+  wasmPadV2 = backend.wasm.cwrap(PadV2, null /* void */, [
     'number',  // xId
     'array',   // x.shape
     'number',  // x.shape.length
     'number',  // x.dtype
-    'array',   // paddings
+    'array',   // pre-paddings
+    'array',   // post-paddings
     'number',  // constantValue
     'number',  // outId
   ]);
 }
 
-function pad(args: {inputs: PadInputs, backend: BackendWasm, attrs: PadAttrs}) {
+function pad(
+    args: {inputs: PadV2Inputs, backend: BackendWasm, attrs: PadV2Attrs}) {
   const {inputs: {x}, backend, attrs: {paddings, constantValue}} = args;
 
   const outShape = paddings.map(
@@ -55,17 +49,23 @@ function pad(args: {inputs: PadInputs, backend: BackendWasm, attrs: PadAttrs}) {
   const out = backend.makeOutput(outShape, x.dtype);
   const outId = backend.dataIdMap.get(out.dataId).id;
   const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer);
-  const paddingsFlat = [].concat(...paddings);
-  const paddingsBytes = new Uint8Array(new Int32Array(paddingsFlat).buffer);
+
+  const prePaddingsFlat = paddings.map(padTuple => padTuple[0]);
+  const postPaddingsFlat = paddings.map(padTuple => padTuple[1]);
+  const prePaddingsBytes =
+      new Uint8Array(new Int32Array(prePaddingsFlat).buffer);
+  const postPaddingsBytes =
+      new Uint8Array(new Int32Array(postPaddingsFlat).buffer);
+
   wasmPadV2(
-      xId, xShapeBytes, x.shape.length, CppDType[x.dtype], paddingsBytes,
-      constantValue, outId);
+      xId, xShapeBytes, x.shape.length, CppDType[x.dtype], prePaddingsBytes,
+      postPaddingsBytes, constantValue, outId);
   return out;
 }
 
 export const padV2Config: KernelConfig = {
-  kernelName: 'PadV2',
+  kernelName: PadV2,
   backendName: 'wasm',
-  kernelFunc: pad,
+  kernelFunc: pad as {} as KernelFunc,
   setupFunc: setup
 };
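Reviewer note on the new calling convention: the wrapper splits each `[before, after]` pair into two flat arrays before serializing them, and on wasm32 `size_t` is 4 bytes, which is why the bytes of an `Int32Array` can back the `const size_t*` parameters. A minimal sketch of the same split as seen from the C++ side (hypothetical values, not part of this PR):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // paddings = [[1, 2], [3, 4]] on the JS side...
  const size_t paddings[2][2] = {{1, 2}, {3, 4}};
  size_t pre_paddings[2];   // padTuple[0] per axis -> {1, 3}
  size_t post_paddings[2];  // padTuple[1] per axis -> {2, 4}
  for (size_t i = 0; i < 2; ++i) {
    pre_paddings[i] = paddings[i][0];
    post_paddings[i] = paddings[i][1];
  }
  std::printf("pre = {%zu, %zu}, post = {%zu, %zu}\n", pre_paddings[0],
              pre_paddings[1], post_paddings[0], post_paddings[1]);
  return 0;
}
```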