From 92c8c8dc1c8791dc2b5c13d24f14c64c587e24d2 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Thu, 12 Mar 2020 15:33:21 -0400 Subject: [PATCH 1/4] save --- tfjs-backend-wasm/package.json | 4 +-- tfjs-backend-wasm/src/kernels/OnesLike.ts | 38 ++++++++++++++++++++++ tfjs-backend-wasm/src/kernels/ZerosLike.ts | 38 ++++++++++++++++++++++ tfjs-backend-wasm/src/setup_test.ts | 3 +- tfjs-backend-wasm/yarn.lock | 14 ++------ tfjs-core/src/ops/tensor_ops.ts | 10 +++--- 6 files changed, 89 insertions(+), 18 deletions(-) create mode 100644 tfjs-backend-wasm/src/kernels/OnesLike.ts create mode 100644 tfjs-backend-wasm/src/kernels/ZerosLike.ts diff --git a/tfjs-backend-wasm/package.json b/tfjs-backend-wasm/package.json index 28ecffbf33b..e843f367892 100644 --- a/tfjs-backend-wasm/package.json +++ b/tfjs-backend-wasm/package.json @@ -32,7 +32,7 @@ "path": false }, "peerDependencies": { - "@tensorflow/tfjs-core": "1.7.0" + "@tensorflow/tfjs-core": "link:../tfjs-core" }, "dependencies": { "@types/emscripten": "~0.0.34" @@ -40,7 +40,7 @@ "devDependencies": { "@bazel/bazel": "^0.28.0", "@bazel/buildifier": "0.29.0", - "@tensorflow/tfjs-core": "1.7.0", + "@tensorflow/tfjs-core": "link:../tfjs-core", "@types/jasmine": "~2.8.6", "clang-format": "~1.2.4", "jasmine": "~3.1.0", diff --git a/tfjs-backend-wasm/src/kernels/OnesLike.ts b/tfjs-backend-wasm/src/kernels/OnesLike.ts new file mode 100644 index 00000000000..09340dbbd05 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/OnesLike.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface OnesLikeInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +function onesLike(args: {inputs: OnesLikeInputs, backend: BackendWasm}) { + const {inputs, backend} = args; + const out = backend.makeOutput(inputs.x.shape, inputs[0].dtype); + const outVals = backend.typedArrayFromHeap(out); + outVals.fill(1); + return out; +} + +registerKernel({ + kernelName: 'OnesLike', + backendName: 'wasm', + kernelFunc: onesLike as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/kernels/ZerosLike.ts b/tfjs-backend-wasm/src/kernels/ZerosLike.ts new file mode 100644 index 00000000000..a8b42ea02a0 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/ZerosLike.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface ZerosLikeInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +function zerosLike(args: {inputs: ZerosLikeInputs, backend: BackendWasm}) { + const {inputs, backend} = args; + const out = backend.makeOutput(inputs.x.shape, inputs[0].dtype); + const outVals = backend.typedArrayFromHeap(out); + outVals.fill(0); + return out; +} + +registerKernel({ + kernelName: 'ZerosLike', + backendName: 'wasm', + kernelFunc: zerosLike as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 01d7c8bbc81..7212fecea09 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -326,7 +326,8 @@ const TEST_FILTERS: TestFilter[] = [ startsWith: 'rsqrt ', excludes: ['gradient'] // Gradient not yet implemented. }, - + {startsWith: 'zerosLike'}, + {startsWith: 'onesLike'}, ]; const customInclude = (testName: string) => { diff --git a/tfjs-backend-wasm/yarn.lock b/tfjs-backend-wasm/yarn.lock index 056632dd1fa..ac1064763a2 100644 --- a/tfjs-backend-wasm/yarn.lock +++ b/tfjs-backend-wasm/yarn.lock @@ -73,17 +73,9 @@ resolved "https://registry.yarnpkg.com/@bazel/hide-bazel-files/-/hide-bazel-files-0.38.3.tgz#e98231d3d360d51860d9c1a7c3345b40dab4cf81" integrity sha512-o+dNkfDm3qxWQ8h/04cWuTcjR7qnjZi3pQGv4aklVb16oPWx2jF8BzbkwvWuIkdbOl9VnqYP0vaHzwQVJRRcIA== -"@tensorflow/tfjs-core@1.7.0": - version "1.7.0" - resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.7.0.tgz#9207c8f2481c52a6a40135a6aaf21a9bb0339bdf" - integrity sha512-uwQdiklNjqBnHPeseOdG0sGxrI3+d6lybaKu2+ou3ajVeKdPEwpWbgqA6iHjq1iylnOGkgkbbnQ6r2lwkiIIHw== - dependencies: - "@types/offscreencanvas" "~2019.3.0" - "@types/seedrandom" "2.4.27" - "@types/webgl-ext" "0.0.30" - "@types/webgl2" "0.0.4" - node-fetch "~2.1.2" - seedrandom "2.4.3" +"@tensorflow/tfjs-core@link:../tfjs-core": + version "0.0.0" + uid "" "@types/emscripten@~0.0.34": version "0.0.34" diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index d9f80efec48..b3b9ad7ea6b 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -490,8 +490,9 @@ function onesLike_(x: T|TensorLike): T { const i = zerosLike(imag($x)); return complex(r, i); } - const der = (dy: T, saved: Tensor[]) => ({$x: () => zerosLike(dy)}); - return ENGINE.runKernelFunc(backend => backend.onesLike($x), {$x}, der) as T; + const der = (dy: T, saved: Tensor[]) => ({x: () => zerosLike(dy)}); + return ENGINE.runKernelFunc( + backend => backend.onesLike($x), {x: $x}, der, 'OnesLike') as T; } /** @@ -508,8 +509,9 @@ function onesLike_(x: T|TensorLike): T { /** @doc {heading: 'Tensors', subheading: 'Creation'} */ function zerosLike_(x: T|TensorLike): T { const $x = convertToTensor(x, 'x', 'zerosLike'); - const der = (dy: T, saved: Tensor[]) => ({$x: () => zerosLike(dy)}); - return ENGINE.runKernelFunc(backend => backend.zerosLike($x), {$x}, der) as T; + const der = (dy: T, saved: Tensor[]) => ({x: () => zerosLike(dy)}); + return ENGINE.runKernelFunc( + backend => backend.zerosLike($x), {x: $x}, der, 'ZerosLike') as T; } /** From f72b31c1e5841ba5c030aaa89b7c913e7eea088b Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Thu, 12 Mar 2020 15:36:10 -0400 Subject: [PATCH 2/4] save --- tfjs-backend-wasm/src/kernels/OnesLike.ts | 4 ++-- tfjs-backend-wasm/src/kernels/ZerosLike.ts | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/OnesLike.ts b/tfjs-backend-wasm/src/kernels/OnesLike.ts index 09340dbbd05..2965b3ac7cf 100644 --- a/tfjs-backend-wasm/src/kernels/OnesLike.ts +++ b/tfjs-backend-wasm/src/kernels/OnesLike.ts @@ -24,8 +24,8 @@ interface OnesLikeInputs extends NamedTensorInfoMap { } function onesLike(args: {inputs: OnesLikeInputs, backend: BackendWasm}) { - const {inputs, backend} = args; - const out = backend.makeOutput(inputs.x.shape, inputs[0].dtype); + const {inputs: {x}, backend} = args; + const out = backend.makeOutput(x.shape, x.dtype); const outVals = backend.typedArrayFromHeap(out); outVals.fill(1); return out; diff --git a/tfjs-backend-wasm/src/kernels/ZerosLike.ts b/tfjs-backend-wasm/src/kernels/ZerosLike.ts index a8b42ea02a0..8bfa0bee773 100644 --- a/tfjs-backend-wasm/src/kernels/ZerosLike.ts +++ b/tfjs-backend-wasm/src/kernels/ZerosLike.ts @@ -24,8 +24,8 @@ interface ZerosLikeInputs extends NamedTensorInfoMap { } function zerosLike(args: {inputs: ZerosLikeInputs, backend: BackendWasm}) { - const {inputs, backend} = args; - const out = backend.makeOutput(inputs.x.shape, inputs[0].dtype); + const {inputs: {x}, backend} = args; + const out = backend.makeOutput(x.shape, x.dtype); const outVals = backend.typedArrayFromHeap(out); outVals.fill(0); return out; From 7f22ac44be95b6d07b75f671f4f99514d778ac17 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Thu, 12 Mar 2020 15:39:23 -0400 Subject: [PATCH 3/4] save --- tfjs-backend-wasm/src/setup_test.ts | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 7212fecea09..7564fc6899b 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -326,8 +326,16 @@ const TEST_FILTERS: TestFilter[] = [ startsWith: 'rsqrt ', excludes: ['gradient'] // Gradient not yet implemented. }, - {startsWith: 'zerosLike'}, - {startsWith: 'onesLike'}, + { + startsWith: 'zerosLike', + // Complex numbers not supported yet. + excludes: ['complex'], + }, + { + startsWith: 'onesLike', + // Complex numbers not supported yet. + excludes: ['complex'], + }, ]; const customInclude = (testName: string) => { From a2a475c7b2d9d72d5600f6f3ab71e1111b072029 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Thu, 12 Mar 2020 16:12:20 -0400 Subject: [PATCH 4/4] save --- tfjs-backend-wasm/src/kernels/all_kernels.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 56cec9e79f0..454fc89d85b 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -29,8 +29,8 @@ import './Cast'; import './ClipByValue'; import './Concat'; import './Conv2D'; -import './CropAndResize'; import './Cos'; +import './CropAndResize'; import './DepthwiseConv2dNative'; import './Div'; import './Exp'; @@ -42,10 +42,10 @@ import './Gather'; import './GatherNd'; import './Greater'; import './GreaterEqual'; -import './LogicalAnd'; import './Less'; import './LessEqual'; import './Log'; +import './LogicalAnd'; import './Max'; import './Maximum'; import './MaxPool'; @@ -56,6 +56,7 @@ import './Neg'; import './NonMaxSuppressionV3'; import './NonMaxSuppressionV5'; import './NotEqual'; +import './OnesLike'; import './PadV2'; import './Pow'; import './Prelu'; @@ -76,3 +77,4 @@ import './Tanh'; import './Tile'; import './Transpose'; import './Unpack'; +import './ZerosLike';