diff --git a/tfjs-backend-wasm/package.json b/tfjs-backend-wasm/package.json index 397df608df0..eadf8323b95 100644 --- a/tfjs-backend-wasm/package.json +++ b/tfjs-backend-wasm/package.json @@ -32,7 +32,7 @@ "path": false }, "peerDependencies": { - "@tensorflow/tfjs-core": "1.5.1" + "@tensorflow/tfjs-core": "link:../tfjs-core" }, "dependencies": { "@types/emscripten": "~0.0.34" @@ -40,7 +40,7 @@ "devDependencies": { "@bazel/bazel": "^0.28.0", "@bazel/buildifier": "0.29.0", - "@tensorflow/tfjs-core": "1.5.1", + "@tensorflow/tfjs-core": "link:../tfjs-core", "@types/jasmine": "~2.8.6", "clang-format": "~1.2.4", "jasmine": "~3.1.0", diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index efa73a29516..9ccd7ee682e 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -172,6 +172,7 @@ tfjs_cc_library( ":Relu", ":Relu6", ":ResizeBilinear", + ":ScatterNd", ":Sigmoid", ":Sub", ":Tile", @@ -586,6 +587,15 @@ tfjs_unit_test( ], ) +tfjs_cc_library( + name = "ScatterNd", + srcs = ["kernels/ScatterNd.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "Sigmoid", srcs = ["kernels/Sigmoid.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/ScatterNd.cc b/tfjs-backend-wasm/src/cc/kernels/ScatterNd.cc new file mode 100644 index 00000000000..d622a43d7fc --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/ScatterNd.cc @@ -0,0 +1,99 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace { +template +void scatter(const int* indices_ptr, const T* updates_ptr, + const size_t slice_rank, const size_t num_updates, + const size_t slice_size, const std::vector& strides_ptr, + const size_t output_size, const size_t dtype_size, + T* out_buf_ptr) { + // Initialize output to 0. + memset(out_buf_ptr, 0, output_size * dtype_size); + + for (size_t i = 0; i < num_updates; ++i) { + size_t flattened_index = 0; + for (size_t j = 0; j < slice_rank; ++j) { + flattened_index += *indices_ptr * strides_ptr[j]; + + indices_ptr++; + } + + out_buf_ptr += flattened_index * slice_size; + + for (size_t k = 0; k < slice_size; ++k) { + *out_buf_ptr += *updates_ptr; + + out_buf_ptr++; + updates_ptr++; + } + + out_buf_ptr -= (flattened_index * slice_size + slice_size); + } +} + +} // namespace + +namespace tfjs { +namespace wasm { +extern "C" { +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void ScatterNd(const size_t indices_id, const size_t updates_id, + const DType dtype, const size_t slice_rank, + const size_t num_updates, const size_t slice_size, + const size_t* strides_ptr, const size_t output_size, + const size_t out_id) { + auto& indices_info = backend::get_tensor_info(indices_id); + auto& updates_info = backend::get_tensor_info(updates_id); + const std::vector& strides = + std::vector(strides_ptr, strides_ptr + slice_rank); + const int* indices_buf = indices_info.i32(); + auto& out_info = backend::get_tensor_info_out(out_id); + + switch (dtype) { + case DType::float32: + scatter(indices_buf, updates_info.f32(), slice_rank, num_updates, + slice_size, strides, output_size, sizeof(float), + out_info.f32_write()); + break; + case DType::int32: + scatter(indices_buf, updates_info.i32(), slice_rank, num_updates, + slice_size, strides, output_size, sizeof(int32_t), + out_info.i32_write()); + break; + case DType::boolean: + scatter(indices_buf, updates_info.b(), slice_rank, num_updates, + slice_size, strides, output_size, sizeof(bool), + out_info.b_write()); + break; + default: + util::warn("Scatter for tensor id %d failed. Unknown dtype %d", + indices_id, dtype); + } +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/ClipByValue.ts b/tfjs-backend-wasm/src/kernels/ClipByValue.ts index 2e3f529addb..a666401d0a9 100644 --- a/tfjs-backend-wasm/src/kernels/ClipByValue.ts +++ b/tfjs-backend-wasm/src/kernels/ClipByValue.ts @@ -1,23 +1,6 @@ /** * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -/** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. + * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at diff --git a/tfjs-backend-wasm/src/kernels/ScatterNd.ts b/tfjs-backend-wasm/src/kernels/ScatterNd.ts new file mode 100644 index 00000000000..d1e01be7435 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/ScatterNd.ts @@ -0,0 +1,88 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, scatter_util, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; +import {CppDType} from './types'; + +interface ScatterNdInputs extends NamedTensorInfoMap { + indices: TensorInfo; + updates: TensorInfo; +} + +interface ScatterNdAttrs extends NamedAttrMap { + shape: number[]; +} + +let wasmScatterNd: ( + indicesId: number, updatesId: number, dtype: CppDType, sliceRank: number, + numUpdates: number, sliceSize: number, strides: Uint8Array, + outputSize: number, outId: number) => void; + +function setup(backend: BackendWasm): void { + wasmScatterNd = backend.wasm.cwrap('ScatterNd', null /*void*/, [ + 'number', // indicesId + 'number', // updatesId + 'number', // dtype + 'number', // sliceRank + 'number', // numUpdates + 'number', // sliceSize + 'array', // strides + 'number', // outputSize + 'number' // outId + ]); +} + +function scatterNd( + args: + {backend: BackendWasm, inputs: ScatterNdInputs, attrs: ScatterNdAttrs}): + TensorInfo { + const {backend, inputs, attrs} = args; + const {indices, updates} = inputs; + const {shape} = attrs; + + const out = backend.makeOutput(shape, updates.dtype); + if (util.sizeFromShape(shape) === 0) { + return out; + } + + const {sliceRank, numUpdates, sliceSize, strides, outputSize} = + scatter_util.calculateShapes(updates, indices, shape); + + const indicesData = backend.dataIdMap.get(indices.dataId); + const indicesId = indicesData.id; + + const updatesData = backend.dataIdMap.get(updates.dataId); + const updatesId = updatesData.id; + + const stridesBytes = new Uint8Array(new Int32Array(strides).buffer); + + const outId = backend.dataIdMap.get(out.dataId).id; + wasmScatterNd( + indicesId, updatesId, CppDType[updates.dtype], sliceRank, numUpdates, + sliceSize, stridesBytes, outputSize, outId); + + return out; +} + +registerKernel({ + kernelName: 'ScatterNd', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: scatterNd +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 389be69b21b..b883514a4fc 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -57,6 +57,7 @@ import './Relu6'; import './Reshape'; import './ResizeBilinear'; import './Rsqrt'; +import './ScatterNd'; import './Sigmoid'; import './Sin'; import './Slice'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 12f7d05bb2a..faa2cb33338 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -125,6 +125,7 @@ const TEST_FILTERS: TestFilter[] = [ 'gradient' // Not yet implemented. ] }, + {include: 'scatterND '}, { include: 'abs ', excludes: [ diff --git a/tfjs-backend-wasm/yarn.lock b/tfjs-backend-wasm/yarn.lock index 8da2ee7c983..ac1064763a2 100644 --- a/tfjs-backend-wasm/yarn.lock +++ b/tfjs-backend-wasm/yarn.lock @@ -73,17 +73,9 @@ resolved "https://registry.yarnpkg.com/@bazel/hide-bazel-files/-/hide-bazel-files-0.38.3.tgz#e98231d3d360d51860d9c1a7c3345b40dab4cf81" integrity sha512-o+dNkfDm3qxWQ8h/04cWuTcjR7qnjZi3pQGv4aklVb16oPWx2jF8BzbkwvWuIkdbOl9VnqYP0vaHzwQVJRRcIA== -"@tensorflow/tfjs-core@1.5.1": - version "1.5.1" - resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.5.1.tgz#490209617f744fef660e8f81fe8b858e95b0d10b" - integrity sha512-N4fsi8mLsRwRs8UJN2cARB1rYFxyVXkLyZ4wOusiR976BwwZbCwQrTTSIPzPqYT3rwiexEUzm7sM6ZaDl5dpXA== - dependencies: - "@types/offscreencanvas" "~2019.3.0" - "@types/seedrandom" "2.4.27" - "@types/webgl-ext" "0.0.30" - "@types/webgl2" "0.0.4" - node-fetch "~2.1.2" - seedrandom "2.4.3" +"@tensorflow/tfjs-core@link:../tfjs-core": + version "0.0.0" + uid "" "@types/emscripten@~0.0.34": version "0.0.34" diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index c03e1591962..e6af207ac74 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -47,6 +47,7 @@ import * as backend_util from './backends/backend_util'; import * as io from './io/io'; import * as math from './math'; import * as browser from './ops/browser'; +import * as scatter_util from './ops/scatter_nd_util'; import * as slice_util from './ops/slice_util'; import * as serialization from './serialization'; import {setOpHandler} from './tensor'; @@ -99,7 +100,8 @@ export { backend_util, webgl, tensor_util, - slice_util + slice_util, + scatter_util }; // Backend specific. diff --git a/tfjs-core/src/ops/scatter_nd.ts b/tfjs-core/src/ops/scatter_nd.ts index 494ad992feb..a09bb02e7fd 100644 --- a/tfjs-core/src/ops/scatter_nd.ts +++ b/tfjs-core/src/ops/scatter_nd.ts @@ -49,7 +49,8 @@ function scatterND_( return ENGINE.runKernelFunc( backend => backend.scatterND($indices, $updates, shape), - {$indices, $updates}); + {indices: $indices, updates: $updates}, null /* backward */, 'ScatterNd', + {shape}); } export const scatterND = op({scatterND_}); diff --git a/tfjs-core/src/ops/scatter_nd_util.ts b/tfjs-core/src/ops/scatter_nd_util.ts index 3bb5b6155da..fd013f02b9c 100644 --- a/tfjs-core/src/ops/scatter_nd_util.ts +++ b/tfjs-core/src/ops/scatter_nd_util.ts @@ -14,6 +14,7 @@ * limitations under the License. * ============================================================================= */ +import {TensorInfo} from '../kernel_registry'; import {Tensor} from '../tensor'; import {computeStrides, sizeFromShape} from '../util'; @@ -123,9 +124,11 @@ export function validateInput( * @returns ScatterShapeInfo */ export function calculateShapes( - updates: Tensor, indices: Tensor, shape: number[]): ScatterShapeInfo { + updates: TensorInfo, indices: TensorInfo, + shape: number[]): ScatterShapeInfo { // Calculate the number of dimensions in indices - const sliceRank = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1; + const indicesRank = indices.shape.length; + const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1; // Calculate the number of elements that make up each slice of our updated // tensor. This allows us to work with flattened tensors and copy over whole @@ -138,7 +141,7 @@ export function calculateShapes( } const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank; - const numUpdates = indices.size / safeSliceDim; + const numUpdates = sizeFromShape(indices.shape) / safeSliceDim; const strides = [...computeStrides(shape.slice(0, sliceRank)), 1]; const outputSize = sizeFromShape(shape);