diff --git a/tfjs-backend-wasm/package.json b/tfjs-backend-wasm/package.json index 28ecffbf33b..e843f367892 100644 --- a/tfjs-backend-wasm/package.json +++ b/tfjs-backend-wasm/package.json @@ -32,7 +32,7 @@ "path": false }, "peerDependencies": { - "@tensorflow/tfjs-core": "1.7.0" + "@tensorflow/tfjs-core": "link:../tfjs-core" }, "dependencies": { "@types/emscripten": "~0.0.34" @@ -40,7 +40,7 @@ "devDependencies": { "@bazel/bazel": "^0.28.0", "@bazel/buildifier": "0.29.0", - "@tensorflow/tfjs-core": "1.7.0", + "@tensorflow/tfjs-core": "link:../tfjs-core", "@types/jasmine": "~2.8.6", "clang-format": "~1.2.4", "jasmine": "~3.1.0", diff --git a/tfjs-backend-wasm/src/kernels/OnesLike.ts b/tfjs-backend-wasm/src/kernels/OnesLike.ts new file mode 100644 index 00000000000..2965b3ac7cf --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/OnesLike.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface OnesLikeInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +function onesLike(args: {inputs: OnesLikeInputs, backend: BackendWasm}) { + const {inputs: {x}, backend} = args; + const out = backend.makeOutput(x.shape, x.dtype); + const outVals = backend.typedArrayFromHeap(out); + outVals.fill(1); + return out; +} + +registerKernel({ + kernelName: 'OnesLike', + backendName: 'wasm', + kernelFunc: onesLike as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/kernels/ZerosLike.ts b/tfjs-backend-wasm/src/kernels/ZerosLike.ts new file mode 100644 index 00000000000..8bfa0bee773 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/ZerosLike.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelFunc, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface ZerosLikeInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +function zerosLike(args: {inputs: ZerosLikeInputs, backend: BackendWasm}) { + const {inputs: {x}, backend} = args; + const out = backend.makeOutput(x.shape, x.dtype); + const outVals = backend.typedArrayFromHeap(out); + outVals.fill(0); + return out; +} + +registerKernel({ + kernelName: 'ZerosLike', + backendName: 'wasm', + kernelFunc: zerosLike as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 56cec9e79f0..454fc89d85b 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -29,8 +29,8 @@ import './Cast'; import './ClipByValue'; import './Concat'; import './Conv2D'; -import './CropAndResize'; import './Cos'; +import './CropAndResize'; import './DepthwiseConv2dNative'; import './Div'; import './Exp'; @@ -42,10 +42,10 @@ import './Gather'; import './GatherNd'; import './Greater'; import './GreaterEqual'; -import './LogicalAnd'; import './Less'; import './LessEqual'; import './Log'; +import './LogicalAnd'; import './Max'; import './Maximum'; import './MaxPool'; @@ -56,6 +56,7 @@ import './Neg'; import './NonMaxSuppressionV3'; import './NonMaxSuppressionV5'; import './NotEqual'; +import './OnesLike'; import './PadV2'; import './Pow'; import './Prelu'; @@ -76,3 +77,4 @@ import './Tanh'; import './Tile'; import './Transpose'; import './Unpack'; +import './ZerosLike'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 01d7c8bbc81..7564fc6899b 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -326,7 +326,16 @@ const TEST_FILTERS: TestFilter[] = [ startsWith: 'rsqrt ', excludes: ['gradient'] // Gradient not yet implemented. }, - + { + startsWith: 'zerosLike', + // Complex numbers not supported yet. + excludes: ['complex'], + }, + { + startsWith: 'onesLike', + // Complex numbers not supported yet. + excludes: ['complex'], + }, ]; const customInclude = (testName: string) => { diff --git a/tfjs-backend-wasm/yarn.lock b/tfjs-backend-wasm/yarn.lock index 056632dd1fa..ac1064763a2 100644 --- a/tfjs-backend-wasm/yarn.lock +++ b/tfjs-backend-wasm/yarn.lock @@ -73,17 +73,9 @@ resolved "https://registry.yarnpkg.com/@bazel/hide-bazel-files/-/hide-bazel-files-0.38.3.tgz#e98231d3d360d51860d9c1a7c3345b40dab4cf81" integrity sha512-o+dNkfDm3qxWQ8h/04cWuTcjR7qnjZi3pQGv4aklVb16oPWx2jF8BzbkwvWuIkdbOl9VnqYP0vaHzwQVJRRcIA== -"@tensorflow/tfjs-core@1.7.0": - version "1.7.0" - resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.7.0.tgz#9207c8f2481c52a6a40135a6aaf21a9bb0339bdf" - integrity sha512-uwQdiklNjqBnHPeseOdG0sGxrI3+d6lybaKu2+ou3ajVeKdPEwpWbgqA6iHjq1iylnOGkgkbbnQ6r2lwkiIIHw== - dependencies: - "@types/offscreencanvas" "~2019.3.0" - "@types/seedrandom" "2.4.27" - "@types/webgl-ext" "0.0.30" - "@types/webgl2" "0.0.4" - node-fetch "~2.1.2" - seedrandom "2.4.3" +"@tensorflow/tfjs-core@link:../tfjs-core": + version "0.0.0" + uid "" "@types/emscripten@~0.0.34": version "0.0.34" diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index d9f80efec48..b3b9ad7ea6b 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -490,8 +490,9 @@ function onesLike_(x: T|TensorLike): T { const i = zerosLike(imag($x)); return complex(r, i); } - const der = (dy: T, saved: Tensor[]) => ({$x: () => zerosLike(dy)}); - return ENGINE.runKernelFunc(backend => backend.onesLike($x), {$x}, der) as T; + const der = (dy: T, saved: Tensor[]) => ({x: () => zerosLike(dy)}); + return ENGINE.runKernelFunc( + backend => backend.onesLike($x), {x: $x}, der, 'OnesLike') as T; } /** @@ -508,8 +509,9 @@ function onesLike_(x: T|TensorLike): T { /** @doc {heading: 'Tensors', subheading: 'Creation'} */ function zerosLike_(x: T|TensorLike): T { const $x = convertToTensor(x, 'x', 'zerosLike'); - const der = (dy: T, saved: Tensor[]) => ({$x: () => zerosLike(dy)}); - return ENGINE.runKernelFunc(backend => backend.zerosLike($x), {$x}, der) as T; + const der = (dy: T, saved: Tensor[]) => ({x: () => zerosLike(dy)}); + return ENGINE.runKernelFunc( + backend => backend.zerosLike($x), {x: $x}, der, 'ZerosLike') as T; } /**