diff --git a/tfjs-backend-wasm/src/kernels/GatherNd.ts b/tfjs-backend-wasm/src/kernels/GatherNd.ts index d3fc9b13f9d..bb197289087 100644 --- a/tfjs-backend-wasm/src/kernels/GatherNd.ts +++ b/tfjs-backend-wasm/src/kernels/GatherNd.ts @@ -15,15 +15,11 @@ * ============================================================================= */ -import {gather_util, NamedTensorInfoMap, registerKernel, Tensor, TensorInfo} from '@tensorflow/tfjs-core'; +import {gather_util, GatherNd, GatherNdInputs, registerKernel, Tensor, TensorInfo} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -import {CppDType} from './types'; -interface GatherNdInputs extends NamedTensorInfoMap { - x: TensorInfo; - indices: TensorInfo; -} +import {CppDType} from './types'; let wasmGatherNd: ( xId: number, dtype: CppDType, indicesId: number, numSlices: number, @@ -46,12 +42,12 @@ function setup(backend: BackendWasm): void { function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}): TensorInfo { const {backend, inputs} = args; - const {x, indices} = inputs; + const {params, indices} = inputs; const [resultShape, numSlices, sliceSize, strides] = - gather_util.prepareAndValidate(x as Tensor, indices as Tensor); + gather_util.prepareAndValidate(params as Tensor, indices as Tensor); - const out = backend.makeOutput(resultShape, x.dtype); + const out = backend.makeOutput(resultShape, params.dtype); if (numSlices === 0) { return out; } @@ -59,7 +55,7 @@ function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}): const indicesShape = indices.shape; const sliceRank = indicesShape[indicesShape.length - 1]; - const xData = backend.dataIdMap.get(x.dataId); + const xData = backend.dataIdMap.get(params.dataId); const xId = xData.id; const indicesData = backend.dataIdMap.get(indices.dataId); const indicesId = indicesData.id; @@ -68,14 +64,14 @@ function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}): const outId = backend.dataIdMap.get(out.dataId).id; wasmGatherNd( - xId, CppDType[x.dtype], indicesId, numSlices, sliceRank, sliceSize, + xId, CppDType[params.dtype], indicesId, numSlices, sliceRank, sliceSize, stridesBytes, outId); return out; } registerKernel({ - kernelName: 'GatherNd', + kernelName: GatherNd, backendName: 'wasm', setupFunc: setup, kernelFunc: gatherNd diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 88890f7a5e2..c9ce90b2320 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -200,6 +200,9 @@ export interface FusedBatchNormAttrs { varianceEpsilon: number; } +export const GatherNd = 'GatherNd'; +export type GatherNdInputs = Pick; + export const Greater = 'Greater'; export type GreaterInputs = BinaryInputs; diff --git a/tfjs-core/src/ops/gather_nd.ts b/tfjs-core/src/ops/gather_nd.ts index 521396856f8..d4e37404e56 100644 --- a/tfjs-core/src/ops/gather_nd.ts +++ b/tfjs-core/src/ops/gather_nd.ts @@ -14,8 +14,10 @@ * limitations under the License. * ============================================================================= */ -import {ENGINE} from '../engine'; +import {ENGINE, ForwardFunc} from '../engine'; +import {GatherNd, GatherNdInputs} from '../kernel_names'; import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {op} from './operation'; @@ -60,8 +62,15 @@ import {op} from './operation'; function gatherND_(x: Tensor|TensorLike, indices: Tensor|TensorLike): Tensor { const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32'); const $x = convertToTensor(x, 'x', 'gatherND'); + + const forward: ForwardFunc = (backend) => { + return backend.gatherND($x, $indices); + }; + + const inputs: GatherNdInputs = {params: $x, indices: $indices}; + return ENGINE.runKernelFunc( - backend => backend.gatherND($x, $indices), {x: $x, indices: $indices}, - null /* backward */, 'GatherNd'); + forward, inputs as {} as NamedTensorMap, null /* gradient */, GatherNd); } + export const gatherND = op({gatherND_});