diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD
index 9ccd7ee682e..4c8d8857eb0 100644
--- a/tfjs-backend-wasm/src/cc/BUILD
+++ b/tfjs-backend-wasm/src/cc/BUILD
@@ -156,6 +156,8 @@ tfjs_cc_library(
         ":FusedBatchNorm",
         ":FusedConv2D",
         ":FusedDepthwiseConv2D",
+        ":Gather",
+        ":GatherNd",
         ":Greater",
         ":GreaterEqual",
         ":Less",
@@ -387,6 +389,24 @@ tfjs_unit_test(
     ],
 )
 
+tfjs_cc_library(
+    name = "Gather",
+    srcs = ["kernels/Gather.cc"],
+    deps = [
+        ":backend",
+        ":util",
+    ],
+)
+
+tfjs_cc_library(
+    name = "GatherNd",
+    srcs = ["kernels/GatherNd.cc"],
+    deps = [
+        ":backend",
+        ":util",
+    ],
+)
+
 tfjs_cc_library(
     name = "Greater",
     srcs = ["kernels/Greater.cc"],
diff --git a/tfjs-backend-wasm/src/cc/kernels/Gather.cc b/tfjs-backend-wasm/src/cc/kernels/Gather.cc
new file mode 100644
index 00000000000..a8510d0813c
--- /dev/null
+++ b/tfjs-backend-wasm/src/cc/kernels/Gather.cc
@@ -0,0 +1,86 @@
+/* Copyright 2019 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ===========================================================================*/
+
+#ifdef __EMSCRIPTEN__
+#include <emscripten.h>
+#endif
+
+#include "src/cc/backend.h"
+#include "src/cc/util.h"
+
+namespace {
+
+template <typename T>
+void gather_impl(const T* x_ptr, const std::vector<size_t>& x_strides,
+                 const int32_t* indices_ptr, const size_t axis,
+                 const size_t out_size, const std::vector<size_t>& out_strides,
+                 T* out_buf_ptr) {
+  for (size_t i = 0; i < out_size; ++i) {
+    auto loc = tfjs::util::offset_to_loc(i, out_strides);
+    const size_t new_loc = loc[axis];
+    loc[axis] = indices_ptr[new_loc];
+
+    const size_t original_index = tfjs::util::loc_to_offset(loc, x_strides);
+
+    *out_buf_ptr = x_ptr[original_index];
+
+    out_buf_ptr++;
+  }
+}
+}  // namespace
+
+namespace tfjs {
+namespace wasm {
+extern "C" {
+#ifdef __EMSCRIPTEN__
+EMSCRIPTEN_KEEPALIVE
+#endif
+
+void Gather(const size_t x_id, const DType dtype, const int32_t* x_strides_ptr,
+            const size_t strides_size, const size_t indices_id,
+            const size_t axis, const int32_t* out_strides_ptr,
+            const size_t out_id) {
+  auto& x_info = backend::get_tensor_info(x_id);
+  auto& indices_info = backend::get_tensor_info(indices_id);
+
+  const int* indices_buf = indices_info.i32();
+  auto& out_info = backend::get_tensor_info_out(out_id);
+  const size_t out_size = out_info.size;
+
+  const auto x_strides =
+      std::vector<size_t>(x_strides_ptr, x_strides_ptr + strides_size);
+  const auto out_strides =
+      std::vector<size_t>(out_strides_ptr, out_strides_ptr + strides_size);
+
+  switch (dtype) {
+    case DType::float32:
+      gather_impl(x_info.f32(), x_strides, indices_buf, axis, out_size,
+                  out_strides, out_info.f32_write());
+      break;
+    case DType::int32:
+      gather_impl(x_info.i32(), x_strides, indices_buf, axis, out_size,
+                  out_strides, out_info.i32_write());
+      break;
+    case DType::boolean:
+      gather_impl(x_info.b(), x_strides, indices_buf, axis, out_size,
+                  out_strides, out_info.b_write());
+      break;
+    default:
+      util::warn("Gather for tensor id %d failed. Unknown dtype %d", x_id,
Unknown dtype %d", x_id, + dtype); + } +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/GatherNd.cc b/tfjs-backend-wasm/src/cc/kernels/GatherNd.cc new file mode 100644 index 00000000000..6d9b1a77ed3 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/GatherNd.cc @@ -0,0 +1,90 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace { + +template +void gathernd_impl(const T* x_ptr, const int32_t* indices_ptr, + const size_t num_slices, const size_t slice_rank, + const size_t slice_size, + const std::vector& strides_ptr, T* out_buf_ptr) { + for (size_t i = 0; i < num_slices; ++i) { + size_t flattened_index = 0; + for (size_t j = 0; j < slice_rank; ++j) { + flattened_index += (*indices_ptr * strides_ptr[j]); + + indices_ptr++; + } + + x_ptr += flattened_index * slice_size; + + for (size_t k = 0; k < slice_size; ++k) { + *out_buf_ptr = *x_ptr; + + out_buf_ptr++; + x_ptr++; + } + + x_ptr -= ((flattened_index + 1) * slice_size); + } +} +} // namespace + +namespace tfjs { +namespace wasm { +extern "C" { +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void GatherNd(const size_t x_id, const DType dtype, const size_t indices_id, + const size_t num_slices, const size_t slice_rank, + const size_t slice_size, const int32_t* strides_ptr, + const size_t out_id) { + auto& x_info = backend::get_tensor_info(x_id); + auto& indices_info = backend::get_tensor_info(indices_id); + const std::vector& strides = + std::vector(strides_ptr, strides_ptr + slice_rank); + + const int* indices_buf = indices_info.i32(); + auto& out_info = backend::get_tensor_info_out(out_id); + + switch (dtype) { + case DType::float32: + gathernd_impl(x_info.f32(), indices_buf, num_slices, slice_rank, + slice_size, strides, out_info.f32_write()); + break; + case DType::int32: + gathernd_impl(x_info.i32(), indices_buf, num_slices, slice_rank, + slice_size, strides, out_info.i32_write()); + break; + case DType::boolean: + gathernd_impl(x_info.b(), indices_buf, num_slices, slice_rank, + slice_size, strides, out_info.b_write()); + break; + default: + util::warn("GatherNd for tensor id %d failed. Unknown dtype %d", + indices_id, dtype); + } +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Gather.ts b/tfjs-backend-wasm/src/kernels/Gather.ts new file mode 100644 index 00000000000..25479954af5 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Gather.ts @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core';
+
+import {BackendWasm} from '../backend_wasm';
+import {CppDType} from './types';
+
+interface GatherInputs extends NamedTensorInfoMap {
+  x: TensorInfo;
+  indices: TensorInfo;
+}
+
+interface GatherAttrs extends NamedAttrMap {
+  axis: number;
+}
+
+let wasmGather:
+    (xId: number, dtype: CppDType, xStrides: Uint8Array, stridesSize: number,
+     indicesId: number, axis: number, outStrides: Uint8Array, outId: number) =>
+        void;
+
+function setup(backend: BackendWasm): void {
+  wasmGather = backend.wasm.cwrap('Gather', null /*void*/, [
+    'number',  // xId
+    'number',  // dtype
+    'array',   // xStrides
+    'number',  // stridesSize
+    'number',  // indicesId
+    'number',  // axis
+    'array',   // outStrides
+    'number'   // outId
+  ]);
+}
+
+function gather(
+    args: {backend: BackendWasm, inputs: GatherInputs, attrs: GatherAttrs}):
+    TensorInfo {
+  const {backend, inputs, attrs} = args;
+  const {x, indices} = inputs;
+  const {axis} = attrs;
+
+  const newShape = x.shape.slice();
+  newShape[axis] = util.sizeFromShape(indices.shape);
+  const stridesSize = x.shape.length - 1;
+
+  const out = backend.makeOutput(newShape, x.dtype);
+  if (util.sizeFromShape(x.shape) === 0) {
+    return out;
+  }
+
+  const xData = backend.dataIdMap.get(x.dataId);
+  const xId = xData.id;
+
+  const indicesData = backend.dataIdMap.get(indices.dataId);
+  const indicesId = indicesData.id;
+
+  const outId = backend.dataIdMap.get(out.dataId).id;
+
+  const xStridesBytes =
+      new Uint8Array(new Int32Array(util.computeStrides(x.shape)).buffer);
+  const outStridesBytes =
+      new Uint8Array(new Int32Array(util.computeStrides(newShape)).buffer);
+
+  wasmGather(
+      xId, CppDType[x.dtype], xStridesBytes, stridesSize, indicesId, axis,
+      outStridesBytes, outId);
+
+  return out;
+}
+
+registerKernel({
+  kernelName: 'Gather',
+  backendName: 'wasm',
+  setupFunc: setup,
+  kernelFunc: gather
+});
diff --git a/tfjs-backend-wasm/src/kernels/GatherNd.ts b/tfjs-backend-wasm/src/kernels/GatherNd.ts
new file mode 100644
index 00000000000..d3fc9b13f9d
--- /dev/null
+++ b/tfjs-backend-wasm/src/kernels/GatherNd.ts
@@ -0,0 +1,82 @@
+/**
+ * @license
+ * Copyright 2019 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {gather_util, NamedTensorInfoMap, registerKernel, Tensor, TensorInfo} from '@tensorflow/tfjs-core';
+
+import {BackendWasm} from '../backend_wasm';
+import {CppDType} from './types';
+
+interface GatherNdInputs extends NamedTensorInfoMap {
+  x: TensorInfo;
+  indices: TensorInfo;
+}
+
+let wasmGatherNd: (
+    xId: number, dtype: CppDType, indicesId: number, numSlices: number,
+    sliceRank: number, sliceSize: number, strides: Uint8Array, outId: number) =>
+    void;
+
+function setup(backend: BackendWasm): void {
+  wasmGatherNd = backend.wasm.cwrap('GatherNd', null /*void*/, [
+    'number',  // xId
+    'number',  // dtype
+    'number',  // indicesId
+    'number',  // numSlices
+    'number',  // sliceRank
+    'number',  // sliceSize
+    'array',   // strides
+    'number'   // outId
+  ]);
+}
+
+function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}):
+    TensorInfo {
+  const {backend, inputs} = args;
+  const {x, indices} = inputs;
+
+  const [resultShape, numSlices, sliceSize, strides] =
+      gather_util.prepareAndValidate(x as Tensor, indices as Tensor);
+
+  const out = backend.makeOutput(resultShape, x.dtype);
+  if (numSlices === 0) {
+    return out;
+  }
+
+  const indicesShape = indices.shape;
+  const sliceRank = indicesShape[indicesShape.length - 1];
+
+  const xData = backend.dataIdMap.get(x.dataId);
+  const xId = xData.id;
+  const indicesData = backend.dataIdMap.get(indices.dataId);
+  const indicesId = indicesData.id;
+
+  const stridesBytes = new Uint8Array(new Int32Array(strides).buffer);
+
+  const outId = backend.dataIdMap.get(out.dataId).id;
+  wasmGatherNd(
+      xId, CppDType[x.dtype], indicesId, numSlices, sliceRank, sliceSize,
+      stridesBytes, outId);
+
+  return out;
+}
+
+registerKernel({
+  kernelName: 'GatherNd',
+  backendName: 'wasm',
+  setupFunc: setup,
+  kernelFunc: gatherNd
+});
diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts
index b883514a4fc..0a47983ee08 100644
--- a/tfjs-backend-wasm/src/kernels/all_kernels.ts
+++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts
@@ -37,6 +37,8 @@ import './FloorDiv';
 import './FusedBatchNorm';
 import './FusedConv2D';
 import './FusedDepthwiseConv2D';
+import './Gather';
+import './GatherNd';
 import './Greater';
 import './GreaterEqual';
 import './LogicalAnd';
diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts
index faa2cb33338..005cb3a7e86 100644
--- a/tfjs-backend-wasm/src/setup_test.ts
+++ b/tfjs-backend-wasm/src/setup_test.ts
@@ -118,6 +118,12 @@ const TEST_FILTERS: TestFilter[] = [
       'shallow slice an input that was cast'  // Slice is not implemented.
     ]
   },
+  {
+    include: 'gather',
+    excludes: [
+      'gradient'  // Not yet implemented.
+    ]
+  },
   {
     include: 'sigmoid ',
     excludes: [
diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts
index e6af207ac74..ffb2638caa7 100644
--- a/tfjs-core/src/index.ts
+++ b/tfjs-core/src/index.ts
@@ -29,7 +29,6 @@ import './engine';
 
 // Register backend-agnostic flags.
 import './flags';
-
 // backend_cpu.ts and backend_webgl.ts are standalone files and should be
 // explicitly included here.
 import './backends/webgl/backend_webgl';
@@ -38,7 +37,6 @@ import './backends/cpu/backend_cpu';
 import './backends/cpu/all_kernels';
 // Import all kernels from webgl.
 import './backends/webgl/all_kernels';
-
 import './platforms/platform_browser';
 import './platforms/platform_node';
 
@@ -47,6 +45,7 @@ import * as backend_util from './backends/backend_util';
 import * as io from './io/io';
 import * as math from './math';
 import * as browser from './ops/browser';
+import * as gather_util from './ops/gather_nd_util';
 import * as scatter_util from './ops/scatter_nd_util';
 import * as slice_util from './ops/slice_util';
 import * as serialization from './serialization';
@@ -101,6 +100,7 @@ export {
   webgl,
   tensor_util,
   slice_util,
+  gather_util,
   scatter_util
 };
 
diff --git a/tfjs-core/src/ops/gather_nd.ts b/tfjs-core/src/ops/gather_nd.ts
index a817963f7da..521396856f8 100644
--- a/tfjs-core/src/ops/gather_nd.ts
+++ b/tfjs-core/src/ops/gather_nd.ts
@@ -61,6 +61,7 @@ function gatherND_(x: Tensor|TensorLike, indices: Tensor|TensorLike): Tensor {
   const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');
   const $x = convertToTensor(x, 'x', 'gatherND');
   return ENGINE.runKernelFunc(
-      backend => backend.gatherND($x, $indices), {$x, $indices});
+      backend => backend.gatherND($x, $indices), {x: $x, indices: $indices},
+      null /* backward */, 'GatherNd');
 }
 export const gatherND = op({gatherND_});
diff --git a/tfjs-core/src/ops/segment_ops.ts b/tfjs-core/src/ops/segment_ops.ts
index 4747452fd42..1a00f6a7a3f 100644
--- a/tfjs-core/src/ops/segment_ops.ts
+++ b/tfjs-core/src/ops/segment_ops.ts
@@ -126,13 +126,16 @@ function gather_(
 
       return paramsGrad as T;
     };
-    return {$x: derX};
+    return {x: derX, indices: () => $indices};
   };
-  return (ENGINE.runKernelFunc((backend, save) => {
-           const res = backend.gather($x, $indices.flatten(), axis);
-           save([$indices]);
-           return res;
-         }, {$x}, grad)).reshape(shapeInfo.outputShape) as T;
+  return (ENGINE.runKernelFunc(
+              (backend, save) => {
+                const res = backend.gather($x, $indices.flatten(), axis);
+                save([$indices]);
+                return res;
+              },
+              {x: $x, indices: $indices}, grad, 'Gather', {axis}))
+      .reshape(shapeInfo.outputShape) as T;
 }
 
 function arrayRange(start: number, stop: number): number[] {
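
A minimal usage sketch (not part of the patch) of how the new kernels would be reached through the public API, assuming the published '@tensorflow/tfjs-backend-wasm' entry point is imported to register the wasm backend:

import * as tf from '@tensorflow/tfjs-core';
// Registers the 'wasm' backend and the kernels pulled in via all_kernels.ts.
import '@tensorflow/tfjs-backend-wasm';

async function main() {
  await tf.setBackend('wasm');

  const x = tf.tensor2d([[1, 2], [3, 4]]);
  const indices = tf.tensor1d([1, 0], 'int32');

  // Dispatches to the wasm 'Gather' kernel: picks rows 1 and 0 along axis 0.
  tf.gather(x, indices, 0).print();  // [[3, 4], [1, 2]]

  // Dispatches to the wasm 'GatherNd' kernel: gathers x[0][1] and x[1][0].
  const ndIndices = tf.tensor2d([[0, 1], [1, 0]], [2, 2], 'int32');
  tf.gatherND(x, ndIndices).print();  // [2, 3]
}

main();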