From f5ed694a5d7a001b8a68e654947e8a70b09bd90f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 26 Jun 2020 14:21:40 -0400 Subject: [PATCH 01/13] init --- tfjs-backend-wasm/src/kernels/Reverse.ts | 49 ++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tfjs-backend-wasm/src/kernels/Reverse.ts diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts new file mode 100644 index 00000000000..94677c41f70 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, Reverse, ReverseAttrs, ReverseInputs, slice_util, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +export function reverse(args: { + inputs: NamedTensorInfoMap, + backend: BackendWasm, + attrs: NamedAttrMap +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs as {} as ReverseInputs; + const {dims} = attrs as {} as ReverseAttrs; + + const axes = util.parseAxisParam(dims, x.shape); + + const out = backend.makeOutput(x.shape, x.dtype); + const xVals = backend.typedArrayFromHeap(x); + const outVals = backend.typedArrayFromHeap(out); + const outBuf = buffer(x.shape, x.dtype, outVals); + + for (let i = 0; i < outVals.length; i++) { + const outLoc = outBuf.indexToLoc(i); + const inLoc = outLoc.slice(); + axes.forEach(ax => inLoc[ax] = x.shape[ax] - 1 - inLoc[ax]); + // let inPos = 0; + outBuf.set(xVals[0], ...outLoc); + } + + return out; +} + +registerKernel({kernelName: Reverse, backendName: 'wasm', kernelFunc: reverse}); From 73f6d28431d0e73663c00fcce6ccadc0b01699eb Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Sun, 28 Jun 2020 13:31:59 -0400 Subject: [PATCH 02/13] clean --- tfjs-backend-wasm/src/kernels/Reverse.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index 94677c41f70..fcb57b29510 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, Reverse, ReverseAttrs, ReverseInputs, slice_util, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; From 30c292da09a676860bdd8f6d0c1ae5566e437277 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 08:02:24 -0400 Subject: [PATCH 03/13] add reshape --- tfjs-backend-wasm/src/kernels/Reverse.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index fcb57b29510..33f9b54c94a 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -18,6 +18,7 @@ import {buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; +import {reshape} from './Reshape'; export function reverse(args: { inputs: NamedTensorInfoMap, @@ -30,6 +31,8 @@ export function reverse(args: { const axes = util.parseAxisParam(dims, x.shape); + // TODO: ADD CLONE + const out = backend.makeOutput(x.shape, x.dtype); const xVals = backend.typedArrayFromHeap(x); const outVals = backend.typedArrayFromHeap(out); @@ -43,7 +46,7 @@ export function reverse(args: { outBuf.set(xVals[0], ...outLoc); } - return out; + return reshape({inputs: {x: out}, attrs: {shape: x.shape}, backend}); } registerKernel({kernelName: Reverse, backendName: 'wasm', kernelFunc: reverse}); From e8fff8f85c47df86f1354f13a6ebfa0c426ee58e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 08:42:21 -0400 Subject: [PATCH 04/13] reverse --- tfjs-backend-wasm/src/cc/BUILD | 10 +++ tfjs-backend-wasm/src/cc/kernels/Reverse.cc | 69 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/Reverse.ts | 44 +++++++++---- tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + 4 files changed, 111 insertions(+), 13 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/Reverse.cc diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 0ef8f64fc2c..a22d9204bd8 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -224,6 +224,7 @@ tfjs_cc_library( ":Relu", ":Relu6", ":ResizeBilinear", + ":Reverse", ":ScatterNd", ":SelectV2", ":Sigmoid", @@ -746,6 +747,15 @@ tfjs_unit_test( ], ) +tfjs_cc_library( + name = "Reverse", + srcs = ["kernels/Reverse.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "ScatterNd", srcs = ["kernels/ScatterNd.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Reverse.cc b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc new file mode 100644 index 00000000000..d6ae4ab18f5 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc @@ -0,0 +1,69 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Reverse(const size_t x_id, const size_t* axes_ptr, + const size_t axes_length, const size_t* out_shape_ptr, + const size_t out_shape_length, const size_t out_id) { + auto out_shape = + std::vector(out_shape_ptr, out_shape_ptr + out_shape_length); + auto axes = std::vector(axes_ptr, axes_ptr + axes_length); + + auto& x_info = backend::get_tensor_info(x_id); + const float* x_buf = x_info.f32(); + + size_t x_size = x_info.size; + + const std::vector out_strides = + tfjs::util::compute_strides(out_shape); + + for (size_t i = 0; i < x_size; ++i) { + const std::vector out_loc = + tfjs::util::offset_to_loc(i, out_strides); + + for (size_t ax_i = 0; ax_i < axes_length; ++ax_i) { + size_t ax = axes[ax_i]; + } + + const std::vector in_loc = out_loc; + for (size_t ax_i = 0; ax_i < axes_length; ++ax_i) { + size_t ax = axes[ax_i]; + in_loc[ax] = out_shape[ax] - 1 - in_loc[ax]; + } + + const size_t x_position = tfjs::util::loc_to_offset(in_loc, out_strides); + out[i] = x_buf[x_position]; + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index 33f9b54c94a..2dc0279917b 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -15,11 +15,26 @@ * ============================================================================= */ -import {buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; import {reshape} from './Reshape'; +let wasmReverse: ( + xId: number, axes: Uint8Array, axesLength: number, outShape: Uint8Array, + outShapeLength: number, outId: number) => void; + +function setup(backend: BackendWasm) { + wasmReverse = backend.wasm.cwrap(Reverse, null, [ + 'number', // x_id + 'array', // axes + 'number', // axes_length + 'array', // out_shape + 'number', // out_shape_length + 'number' // out_id + ]); +} + export function reverse(args: { inputs: NamedTensorInfoMap, backend: BackendWasm, @@ -34,19 +49,22 @@ export function reverse(args: { // TODO: ADD CLONE const out = backend.makeOutput(x.shape, x.dtype); - const xVals = backend.typedArrayFromHeap(x); - const outVals = backend.typedArrayFromHeap(out); - const outBuf = buffer(x.shape, x.dtype, outVals); - - for (let i = 0; i < outVals.length; i++) { - const outLoc = outBuf.indexToLoc(i); - const inLoc = outLoc.slice(); - axes.forEach(ax => inLoc[ax] = x.shape[ax] - 1 - inLoc[ax]); - // let inPos = 0; - outBuf.set(xVals[0], ...outLoc); - } + const xId = backend.dataIdMap.get(x.dataId).id; + const outId = backend.dataIdMap.get(out.dataId).id; + + const axesBytes = new Uint8Array(new Int32Array(axes).buffer); + + const outShapeBytes = new Uint8Array(new Int32Array(x.shape)); + + wasmReverse( + xId, axesBytes, axes.length, outShapeBytes, x.shape.length, outId); return reshape({inputs: {x: out}, attrs: {shape: x.shape}, backend}); } -registerKernel({kernelName: Reverse, backendName: 'wasm', kernelFunc: reverse}); +registerKernel({ + kernelName: Reverse, + backendName: 'wasm', + kernelFunc: reverse, + setupFunc: setup +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index ee032cfc119..1f304465401 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -68,6 +68,7 @@ import './Relu'; import './Relu6'; import './Reshape'; import './ResizeBilinear'; +import './Reverse'; import './Rsqrt'; import './SelectV2'; import './ScatterNd'; From 870a946d99f7cb5ff7e82dd4a53abcf3e8c54294 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 09:03:29 -0400 Subject: [PATCH 05/13] setup --- tfjs-backend-wasm/src/cc/kernels/Reverse.cc | 11 +++---- tfjs-backend-wasm/src/index_test.ts | 34 +++++++++++++++++++-- tfjs-backend-wasm/src/kernels/Reverse.ts | 3 +- tfjs-backend-wasm/src/setup_test.ts | 1 + 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/Reverse.cc b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc index d6ae4ab18f5..7f1459decc2 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Reverse.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc @@ -40,6 +40,9 @@ void Reverse(const size_t x_id, const size_t* axes_ptr, auto& x_info = backend::get_tensor_info(x_id); const float* x_buf = x_info.f32(); + auto& out_info = backend::get_tensor_info_out(out_id); + float* out_buf = out_info.f32_write(); + size_t x_size = x_info.size; const std::vector out_strides = @@ -49,18 +52,14 @@ void Reverse(const size_t x_id, const size_t* axes_ptr, const std::vector out_loc = tfjs::util::offset_to_loc(i, out_strides); - for (size_t ax_i = 0; ax_i < axes_length; ++ax_i) { - size_t ax = axes[ax_i]; - } - - const std::vector in_loc = out_loc; + std::vector in_loc = out_loc; for (size_t ax_i = 0; ax_i < axes_length; ++ax_i) { size_t ax = axes[ax_i]; in_loc[ax] = out_shape[ax] - 1 - in_loc[ax]; } const size_t x_position = tfjs::util::loc_to_offset(in_loc, out_strides); - out[i] = x_buf[x_position]; + out_buf[i] = x_buf[x_position]; } } diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index d7f2bd6e914..27f68c9a6d3 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -73,8 +73,8 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { }, 100); // Silences backend registration warnings. - spyOn(console, 'warn'); - spyOn(console, 'log'); + // spyOn(console, 'warn'); + // spyOn(console, 'log'); }); afterEach(() => { @@ -138,4 +138,34 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { expect(() => setWasmPath('too/late')) .toThrowError(/The WASM backend was already initialized. Make sure/); }); + + it('accepts a tensor-like object', async () => { + const input = [1, 2, 3]; + const result = tf.reverse(input); + expect(result.shape).toEqual([3]); + const data = await result.data(); + console.log(data); + // expectArraysClose(await result.data(), [3, 2, 1]); + }); + + fit('reverse a 4D array at axis [0]', async () => { + const shape: [number, number, number, number] = [3, 2, 3, 4]; + const data = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71 + ]; + const input = tf.tensor4d(data, shape); + const result = tf.reverse4d(input, [0]); + expect(result.shape).toEqual(input.shape); + const out = await result.data(); + console.log(Array.from(out)); + // expectArraysClose(await result.data(), [ + // 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + // 66, 67, 68, 69, 70, 71, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + // 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 0, 1, 2, 3, 4, 5, + // 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 + // ]); + }); }); diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index 2dc0279917b..fe700fd08c9 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -53,9 +53,10 @@ export function reverse(args: { const outId = backend.dataIdMap.get(out.dataId).id; const axesBytes = new Uint8Array(new Int32Array(axes).buffer); - const outShapeBytes = new Uint8Array(new Int32Array(x.shape)); + console.log(axes, x.shape); + wasmReverse( xId, axesBytes, axes.length, outShapeBytes, x.shape.length, outId); diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 0c14b49c728..8483f83f2d6 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -308,6 +308,7 @@ const TEST_FILTERS: TestFilter[] = [ 'axis=0', // Reduction not supported along inner dimensions. ] }, + {startsWith: 'reverse'}, {startsWith: 'sum '}, { startsWith: 'logicalAnd ', From efee4bcadc31746c43189f3fd43c59b49c7b6578 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 09:46:53 -0400 Subject: [PATCH 06/13] fix --- tfjs-backend-wasm/src/index_test.ts | 30 ------------------------ tfjs-backend-wasm/src/kernels/Reverse.ts | 4 +--- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 27f68c9a6d3..67be9de62af 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -138,34 +138,4 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { expect(() => setWasmPath('too/late')) .toThrowError(/The WASM backend was already initialized. Make sure/); }); - - it('accepts a tensor-like object', async () => { - const input = [1, 2, 3]; - const result = tf.reverse(input); - expect(result.shape).toEqual([3]); - const data = await result.data(); - console.log(data); - // expectArraysClose(await result.data(), [3, 2, 1]); - }); - - fit('reverse a 4D array at axis [0]', async () => { - const shape: [number, number, number, number] = [3, 2, 3, 4]; - const data = [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, - 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71 - ]; - const input = tf.tensor4d(data, shape); - const result = tf.reverse4d(input, [0]); - expect(result.shape).toEqual(input.shape); - const out = await result.data(); - console.log(Array.from(out)); - // expectArraysClose(await result.data(), [ - // 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, - // 66, 67, 68, 69, 70, 71, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, - // 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 0, 1, 2, 3, 4, 5, - // 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 - // ]); - }); }); diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index fe700fd08c9..77abb014543 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -53,9 +53,7 @@ export function reverse(args: { const outId = backend.dataIdMap.get(out.dataId).id; const axesBytes = new Uint8Array(new Int32Array(axes).buffer); - const outShapeBytes = new Uint8Array(new Int32Array(x.shape)); - - console.log(axes, x.shape); + const outShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); wasmReverse( xId, axesBytes, axes.length, outShapeBytes, x.shape.length, outId); From 20ca72b736f9dff82678982a47a91a612b5f58a9 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 09:48:25 -0400 Subject: [PATCH 07/13] clean --- tfjs-backend-wasm/src/index_test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 67be9de62af..d7f2bd6e914 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -73,8 +73,8 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { }, 100); // Silences backend registration warnings. - // spyOn(console, 'warn'); - // spyOn(console, 'log'); + spyOn(console, 'warn'); + spyOn(console, 'log'); }); afterEach(() => { From 679c90e55eccf03dacc69cd5ac9e8f68385de006 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 10:09:04 -0400 Subject: [PATCH 08/13] add clone --- tfjs-backend-wasm/src/backend_wasm.ts | 6 +++++- tfjs-backend-wasm/src/kernels/Reverse.ts | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index 0299073da66..3ecdd8c4fde 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -16,7 +16,7 @@ */ import './flags_wasm'; -import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, env, KernelBackend, registerBackend, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, env, KernelBackend, registerBackend, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasmModule, WasmFactoryConfig} from '../wasm-out/tfjs-backend-wasm'; import wasmFactorySimd from '../wasm-out/tfjs-backend-wasm-simd.js'; @@ -64,6 +64,10 @@ export class BackendWasm extends KernelBackend { return {kernelMs}; } + clone(x: TensorInfo): Tensor { + return engine().makeTensorFromDataId(x.dataId, x.shape, x.dtype); + } + move( dataId: DataId, values: backend_util.BackendValues, shape: number[], dtype: DataType): void { diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index 77abb014543..571f4a2bab7 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -46,7 +46,9 @@ export function reverse(args: { const axes = util.parseAxisParam(dims, x.shape); - // TODO: ADD CLONE + if (x.shape.length === 0) { + return backend.clone(x); + } const out = backend.makeOutput(x.shape, x.dtype); const xId = backend.dataIdMap.get(x.dataId).id; From a8fd5392e5c7211d1c7c9f5325d1ef6db07a9288 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 16:50:10 -0400 Subject: [PATCH 09/13] comments --- tfjs-backend-wasm/src/cc/kernels/Reverse.cc | 4 +--- tfjs-backend-wasm/src/kernels/Reverse.ts | 17 ++++++++--------- tfjs-core/src/index.ts | 2 +- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/Reverse.cc b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc index 7f1459decc2..7aa290639e8 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Reverse.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc @@ -49,10 +49,8 @@ void Reverse(const size_t x_id, const size_t* axes_ptr, tfjs::util::compute_strides(out_shape); for (size_t i = 0; i < x_size; ++i) { - const std::vector out_loc = - tfjs::util::offset_to_loc(i, out_strides); + std::vector in_loc = tfjs::util::offset_to_loc(i, out_strides); - std::vector in_loc = out_loc; for (size_t ax_i = 0; ax_i < axes_length; ++ax_i) { size_t ax = axes[ax_i]; in_loc[ax] = out_shape[ax] - 1 - in_loc[ax]; diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index 571f4a2bab7..6d0aa6441cb 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -15,9 +15,10 @@ * ============================================================================= */ -import {NamedAttrMap, NamedTensorInfoMap, registerKernel, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {ForwardFunc, registerKernel, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; + import {reshape} from './Reshape'; let wasmReverse: ( @@ -35,14 +36,12 @@ function setup(backend: BackendWasm) { ]); } -export function reverse(args: { - inputs: NamedTensorInfoMap, - backend: BackendWasm, - attrs: NamedAttrMap -}): TensorInfo { +export function reverse( + args: {inputs: ReverseInputs, backend: BackendWasm, attrs: ReverseAttrs}): + TensorInfo { const {inputs, backend, attrs} = args; - const {x} = inputs as {} as ReverseInputs; - const {dims} = attrs as {} as ReverseAttrs; + const {x} = inputs; + const {dims} = attrs; const axes = util.parseAxisParam(dims, x.shape); @@ -66,6 +65,6 @@ export function reverse(args: { registerKernel({ kernelName: Reverse, backendName: 'wasm', - kernelFunc: reverse, + kernelFunc: reverse as ForwardFunc, setupFunc: setup }); diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index 71756dda83b..2dd95b94200 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -71,7 +71,7 @@ export * from './globals'; export * from './kernel_registry'; export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients'; -export {TimingInfo, MemoryInfo} from './engine'; +export {TimingInfo, MemoryInfo, ForwardFunc} from './engine'; export {Environment, env, ENV} from './environment'; export {Platform} from './platforms/platform'; From 36caabfb118ab24884371164916100853e22f120 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 16:55:03 -0400 Subject: [PATCH 10/13] add register all --- tfjs-backend-wasm/src/index.ts | 1 + tfjs-backend-wasm/src/kernels/Reverse.ts | 6 ++-- tfjs-backend-wasm/src/register_all_kernels.ts | 29 +++++++++++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 tfjs-backend-wasm/src/register_all_kernels.ts diff --git a/tfjs-backend-wasm/src/index.ts b/tfjs-backend-wasm/src/index.ts index 5581dd7e02c..a2c1b2a4d48 100644 --- a/tfjs-backend-wasm/src/index.ts +++ b/tfjs-backend-wasm/src/index.ts @@ -16,6 +16,7 @@ */ import './kernels/all_kernels'; +import './register_all_kernels'; export {BackendWasm, setWasmPath} from './backend_wasm'; export {version as version_wasm} from './version'; diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index 6d0aa6441cb..e9fc4ba9957 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {ForwardFunc, registerKernel, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {ForwardFunc, KernelConfig, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -62,9 +62,9 @@ export function reverse( return reshape({inputs: {x: out}, attrs: {shape: x.shape}, backend}); } -registerKernel({ +export const reverseConfig: KernelConfig = { kernelName: Reverse, backendName: 'wasm', kernelFunc: reverse as ForwardFunc, setupFunc: setup -}); +}; diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts new file mode 100644 index 00000000000..87c151df493 --- /dev/null +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +// We explicitly import the modular kernels so they get registered in the +// global registry when we compile the library. A modular build would replace +// the contents of this file and import only the kernels that are needed. +import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; + +import {reverseConfig} from './kernels/Reverse'; + +// List all kernel configs here +const kernelConfigs: KernelConfig[] = [reverseConfig]; + +for (const kernelConfig of kernelConfigs) { + registerKernel(kernelConfig); +} From b018b67dd1e4a023d0c82edb36e148bd78aa5ebc Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 29 Jun 2020 17:03:36 -0400 Subject: [PATCH 11/13] fix --- tfjs-backend-wasm/src/kernels/Reverse.ts | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index e9fc4ba9957..94c836d770b 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {ForwardFunc, KernelConfig, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {KernelConfig, NamedAttrMap, NamedTensorInfoMap, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -36,12 +36,14 @@ function setup(backend: BackendWasm) { ]); } -export function reverse( - args: {inputs: ReverseInputs, backend: BackendWasm, attrs: ReverseAttrs}): - TensorInfo { +export function reverse(args: { + inputs: NamedTensorInfoMap, + backend: BackendWasm, + attrs: NamedAttrMap +}): TensorInfo { const {inputs, backend, attrs} = args; - const {x} = inputs; - const {dims} = attrs; + const {x} = inputs as {} as ReverseInputs; + const {dims} = attrs as {} as ReverseAttrs; const axes = util.parseAxisParam(dims, x.shape); @@ -65,6 +67,6 @@ export function reverse( export const reverseConfig: KernelConfig = { kernelName: Reverse, backendName: 'wasm', - kernelFunc: reverse as ForwardFunc, + kernelFunc: reverse, setupFunc: setup }; From ab011c1e035b2ab13a9bb1747325eba622250226 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 30 Jun 2020 12:33:14 -0400 Subject: [PATCH 12/13] type --- tfjs-backend-wasm/src/kernels/Reverse.ts | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index 94c836d770b..6e2ed89c749 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {KernelConfig, NamedAttrMap, NamedTensorInfoMap, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {KernelConfig, KernelFunc, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -36,14 +36,12 @@ function setup(backend: BackendWasm) { ]); } -export function reverse(args: { - inputs: NamedTensorInfoMap, - backend: BackendWasm, - attrs: NamedAttrMap -}): TensorInfo { +export function reverse( + args: {inputs: ReverseInputs, backend: BackendWasm, attrs: ReverseAttrs}): + TensorInfo { const {inputs, backend, attrs} = args; - const {x} = inputs as {} as ReverseInputs; - const {dims} = attrs as {} as ReverseAttrs; + const {x} = inputs; + const {dims} = attrs; const axes = util.parseAxisParam(dims, x.shape); @@ -67,6 +65,6 @@ export function reverse(args: { export const reverseConfig: KernelConfig = { kernelName: Reverse, backendName: 'wasm', - kernelFunc: reverse, + kernelFunc: reverse as {} as KernelFunc, setupFunc: setup }; From 607b5ef142ed7e292ae86508feec81ca51bf1252 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 6 Jul 2020 09:50:32 -0400 Subject: [PATCH 13/13] add clone --- tfjs-backend-wasm/src/backend_wasm.ts | 6 +--- tfjs-backend-wasm/src/kernels/Identity.ts | 37 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/Reverse.ts | 3 +- tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-core/src/ops/clone.ts | 9 +++-- 5 files changed, 47 insertions(+), 9 deletions(-) create mode 100644 tfjs-backend-wasm/src/kernels/Identity.ts diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index 3ecdd8c4fde..0299073da66 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -16,7 +16,7 @@ */ import './flags_wasm'; -import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, env, KernelBackend, registerBackend, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, env, KernelBackend, registerBackend, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasmModule, WasmFactoryConfig} from '../wasm-out/tfjs-backend-wasm'; import wasmFactorySimd from '../wasm-out/tfjs-backend-wasm-simd.js'; @@ -64,10 +64,6 @@ export class BackendWasm extends KernelBackend { return {kernelMs}; } - clone(x: TensorInfo): Tensor { - return engine().makeTensorFromDataId(x.dataId, x.shape, x.dtype); - } - move( dataId: DataId, values: backend_util.BackendValues, shape: number[], dtype: DataType): void { diff --git a/tfjs-backend-wasm/src/kernels/Identity.ts b/tfjs-backend-wasm/src/kernels/Identity.ts new file mode 100644 index 00000000000..edfdb824914 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Identity.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Identity, IdentityInputs, KernelFunc, registerKernel} from '@tensorflow/tfjs-core'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +export function identity(args: {inputs: IdentityInputs, backend: BackendWasm}): + TensorInfo { + const {inputs: {x}, backend} = args; + const out = backend.makeOutput(x.shape, x.dtype); + const inVals = backend.typedArrayFromHeap(x); + const outVals = backend.typedArrayFromHeap(out); + outVals.set(inVals); + return out; +} + +registerKernel({ + kernelName: Identity, + backendName: 'wasm', + kernelFunc: identity as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts index 6e2ed89c749..836ac304cb4 100644 --- a/tfjs-backend-wasm/src/kernels/Reverse.ts +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -19,6 +19,7 @@ import {KernelConfig, KernelFunc, Reverse, ReverseAttrs, ReverseInputs, TensorIn import {BackendWasm} from '../backend_wasm'; +import {identity} from './Identity'; import {reshape} from './Reshape'; let wasmReverse: ( @@ -46,7 +47,7 @@ export function reverse( const axes = util.parseAxisParam(dims, x.shape); if (x.shape.length === 0) { - return backend.clone(x); + return identity({inputs: {x}, backend}); } const out = backend.makeOutput(x.shape, x.dtype); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 1f304465401..c8cea76c002 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -45,6 +45,7 @@ import './Gather'; import './GatherNd'; import './Greater'; import './GreaterEqual'; +import './Identity'; import './Less'; import './LessEqual'; import './Log'; diff --git a/tfjs-core/src/ops/clone.ts b/tfjs-core/src/ops/clone.ts index 4cffbd9bae8..cfbd45003fe 100644 --- a/tfjs-core/src/ops/clone.ts +++ b/tfjs-core/src/ops/clone.ts @@ -16,9 +16,9 @@ */ import {ENGINE} from '../engine'; -import {Identity} from '../kernel_names'; - +import {Identity, IdentityInputs} from '../kernel_names'; import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; @@ -42,9 +42,12 @@ function clone_(x: T|TensorLike): T { const forward = () => ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype) as T; + const inputs: IdentityInputs = {x: $x}; + // Note this op is called tf.identity in python. Hence the kernel name used // here. - return ENGINE.runKernelFunc(forward, {x: $x}, null /* grad */, Identity); + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, Identity); } export const clone = op({clone_});