Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions tfjs-backend-wasm/src/kernels/GatherNd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@
* =============================================================================
*/

import {gather_util, NamedTensorInfoMap, registerKernel, Tensor, TensorInfo} from '@tensorflow/tfjs-core';
import {gather_util, GatherNd, GatherNdInputs, registerKernel, Tensor, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';
import {CppDType} from './types';

interface GatherNdInputs extends NamedTensorInfoMap {
x: TensorInfo;
indices: TensorInfo;
}
import {CppDType} from './types';

let wasmGatherNd: (
xId: number, dtype: CppDType, indicesId: number, numSlices: number,
Expand All @@ -46,20 +42,20 @@ function setup(backend: BackendWasm): void {
function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}):
TensorInfo {
const {backend, inputs} = args;
const {x, indices} = inputs;
const {params, indices} = inputs;

const [resultShape, numSlices, sliceSize, strides] =
gather_util.prepareAndValidate(x as Tensor, indices as Tensor);
gather_util.prepareAndValidate(params as Tensor, indices as Tensor);

const out = backend.makeOutput(resultShape, x.dtype);
const out = backend.makeOutput(resultShape, params.dtype);
if (numSlices === 0) {
return out;
}

const indicesShape = indices.shape;
const sliceRank = indicesShape[indicesShape.length - 1];

const xData = backend.dataIdMap.get(x.dataId);
const xData = backend.dataIdMap.get(params.dataId);
const xId = xData.id;
const indicesData = backend.dataIdMap.get(indices.dataId);
const indicesId = indicesData.id;
Expand All @@ -68,14 +64,14 @@ function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}):

const outId = backend.dataIdMap.get(out.dataId).id;
wasmGatherNd(
xId, CppDType[x.dtype], indicesId, numSlices, sliceRank, sliceSize,
xId, CppDType[params.dtype], indicesId, numSlices, sliceRank, sliceSize,
stridesBytes, outId);

return out;
}

registerKernel({
kernelName: 'GatherNd',
kernelName: GatherNd,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: gatherNd
Expand Down
3 changes: 3 additions & 0 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ export interface FusedBatchNormAttrs {
varianceEpsilon: number;
}

export const GatherNd = 'GatherNd';
export type GatherNdInputs = Pick<NamedTensorInfoMap, 'params'|'indices'>;

export const Greater = 'Greater';
export type GreaterInputs = BinaryInputs;

Expand Down
15 changes: 12 additions & 3 deletions tfjs-core/src/ops/gather_nd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
* limitations under the License.
* =============================================================================
*/
import {ENGINE} from '../engine';
import {ENGINE, ForwardFunc} from '../engine';
import {GatherNd, GatherNdInputs} from '../kernel_names';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {op} from './operation';
Expand Down Expand Up @@ -60,8 +62,15 @@ import {op} from './operation';
function gatherND_(x: Tensor|TensorLike, indices: Tensor|TensorLike): Tensor {
const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');
const $x = convertToTensor(x, 'x', 'gatherND');

const forward: ForwardFunc<Tensor> = (backend) => {
return backend.gatherND($x, $indices);
};

const inputs: GatherNdInputs = {params: $x, indices: $indices};

return ENGINE.runKernelFunc(
backend => backend.gatherND($x, $indices), {x: $x, indices: $indices},
null /* backward */, 'GatherNd');
forward, inputs as {} as NamedTensorMap, null /* gradient */, GatherNd);
}

export const gatherND = op({gatherND_});