diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 0ef8f64fc2c..a22d9204bd8 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -224,6 +224,7 @@ tfjs_cc_library( ":Relu", ":Relu6", ":ResizeBilinear", + ":Reverse", ":ScatterNd", ":SelectV2", ":Sigmoid", @@ -746,6 +747,15 @@ tfjs_unit_test( ], ) +tfjs_cc_library( + name = "Reverse", + srcs = ["kernels/Reverse.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "ScatterNd", srcs = ["kernels/ScatterNd.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Reverse.cc b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc new file mode 100644 index 00000000000..7aa290639e8 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Reverse(const size_t x_id, const size_t* axes_ptr, + const size_t axes_length, const size_t* out_shape_ptr, + const size_t out_shape_length, const size_t out_id) { + auto out_shape = + std::vector(out_shape_ptr, out_shape_ptr + out_shape_length); + auto axes = std::vector(axes_ptr, axes_ptr + axes_length); + + auto& x_info = backend::get_tensor_info(x_id); + const float* x_buf = x_info.f32(); + + auto& out_info = backend::get_tensor_info_out(out_id); + float* out_buf = out_info.f32_write(); + + size_t x_size = x_info.size; + + const std::vector out_strides = + tfjs::util::compute_strides(out_shape); + + for (size_t i = 0; i < x_size; ++i) { + std::vector in_loc = tfjs::util::offset_to_loc(i, out_strides); + + for (size_t ax_i = 0; ax_i < axes_length; ++ax_i) { + size_t ax = axes[ax_i]; + in_loc[ax] = out_shape[ax] - 1 - in_loc[ax]; + } + + const size_t x_position = tfjs::util::loc_to_offset(in_loc, out_strides); + out_buf[i] = x_buf[x_position]; + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/index.ts b/tfjs-backend-wasm/src/index.ts index 5581dd7e02c..a2c1b2a4d48 100644 --- a/tfjs-backend-wasm/src/index.ts +++ b/tfjs-backend-wasm/src/index.ts @@ -16,6 +16,7 @@ */ import './kernels/all_kernels'; +import './register_all_kernels'; export {BackendWasm, setWasmPath} from './backend_wasm'; export {version as version_wasm} from './version'; diff --git a/tfjs-backend-wasm/src/kernels/Identity.ts b/tfjs-backend-wasm/src/kernels/Identity.ts new file mode 100644 index 00000000000..edfdb824914 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Identity.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Identity, IdentityInputs, KernelFunc, registerKernel} from '@tensorflow/tfjs-core'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +export function identity(args: {inputs: IdentityInputs, backend: BackendWasm}): + TensorInfo { + const {inputs: {x}, backend} = args; + const out = backend.makeOutput(x.shape, x.dtype); + const inVals = backend.typedArrayFromHeap(x); + const outVals = backend.typedArrayFromHeap(out); + outVals.set(inVals); + return out; +} + +registerKernel({ + kernelName: Identity, + backendName: 'wasm', + kernelFunc: identity as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts new file mode 100644 index 00000000000..836ac304cb4 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -0,0 +1,71 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +import {identity} from './Identity'; +import {reshape} from './Reshape'; + +let wasmReverse: ( + xId: number, axes: Uint8Array, axesLength: number, outShape: Uint8Array, + outShapeLength: number, outId: number) => void; + +function setup(backend: BackendWasm) { + wasmReverse = backend.wasm.cwrap(Reverse, null, [ + 'number', // x_id + 'array', // axes + 'number', // axes_length + 'array', // out_shape + 'number', // out_shape_length + 'number' // out_id + ]); +} + +export function reverse( + args: {inputs: ReverseInputs, backend: BackendWasm, attrs: ReverseAttrs}): + TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + const {dims} = attrs; + + const axes = util.parseAxisParam(dims, x.shape); + + if (x.shape.length === 0) { + return identity({inputs: {x}, backend}); + } + + const out = backend.makeOutput(x.shape, x.dtype); + const xId = backend.dataIdMap.get(x.dataId).id; + const outId = backend.dataIdMap.get(out.dataId).id; + + const axesBytes = new Uint8Array(new Int32Array(axes).buffer); + const outShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); + + wasmReverse( + xId, axesBytes, axes.length, outShapeBytes, x.shape.length, outId); + + return reshape({inputs: {x: out}, attrs: {shape: x.shape}, backend}); +} + +export const reverseConfig: KernelConfig = { + kernelName: Reverse, + backendName: 'wasm', + kernelFunc: reverse as {} as KernelFunc, + setupFunc: setup +}; diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index ee032cfc119..c8cea76c002 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -45,6 +45,7 @@ import './Gather'; import './GatherNd'; import './Greater'; import './GreaterEqual'; +import './Identity'; import './Less'; import './LessEqual'; import './Log'; @@ -68,6 +69,7 @@ import './Relu'; import './Relu6'; import './Reshape'; import './ResizeBilinear'; +import './Reverse'; import './Rsqrt'; import './SelectV2'; import './ScatterNd'; diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts new file mode 100644 index 00000000000..87c151df493 --- /dev/null +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +// We explicitly import the modular kernels so they get registered in the +// global registry when we compile the library. A modular build would replace +// the contents of this file and import only the kernels that are needed. +import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; + +import {reverseConfig} from './kernels/Reverse'; + +// List all kernel configs here +const kernelConfigs: KernelConfig[] = [reverseConfig]; + +for (const kernelConfig of kernelConfigs) { + registerKernel(kernelConfig); +} diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 0c14b49c728..8483f83f2d6 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -308,6 +308,7 @@ const TEST_FILTERS: TestFilter[] = [ 'axis=0', // Reduction not supported along inner dimensions. ] }, + {startsWith: 'reverse'}, {startsWith: 'sum '}, { startsWith: 'logicalAnd ', diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index 71756dda83b..2dd95b94200 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -71,7 +71,7 @@ export * from './globals'; export * from './kernel_registry'; export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients'; -export {TimingInfo, MemoryInfo} from './engine'; +export {TimingInfo, MemoryInfo, ForwardFunc} from './engine'; export {Environment, env, ENV} from './environment'; export {Platform} from './platforms/platform'; diff --git a/tfjs-core/src/ops/clone.ts b/tfjs-core/src/ops/clone.ts index 4cffbd9bae8..cfbd45003fe 100644 --- a/tfjs-core/src/ops/clone.ts +++ b/tfjs-core/src/ops/clone.ts @@ -16,9 +16,9 @@ */ import {ENGINE} from '../engine'; -import {Identity} from '../kernel_names'; - +import {Identity, IdentityInputs} from '../kernel_names'; import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; @@ -42,9 +42,12 @@ function clone_(x: T|TensorLike): T { const forward = () => ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype) as T; + const inputs: IdentityInputs = {x: $x}; + // Note this op is called tf.identity in python. Hence the kernel name used // here. - return ENGINE.runKernelFunc(forward, {x: $x}, null /* grad */, Identity); + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, Identity); } export const clone = op({clone_});