From da57176b67c8b5350787ae93cbb9034a393b85cb Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Thu, 4 Jun 2020 10:36:35 -0400 Subject: [PATCH 1/2] modularize GatherND --- tfjs-core/src/kernel_names.ts | 3 +++ tfjs-core/src/ops/gather_nd.ts | 15 ++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 4d92b597b71..1d7f75929be 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -197,6 +197,9 @@ export interface FusedBatchNormAttrs { varianceEpsilon: number; } +export const GatherND = 'GatherND'; +export type GatherNDInputs = Pick; + export const Greater = 'Greater'; export type GreaterInputs = BinaryInputs; diff --git a/tfjs-core/src/ops/gather_nd.ts b/tfjs-core/src/ops/gather_nd.ts index 521396856f8..197702c95e7 100644 --- a/tfjs-core/src/ops/gather_nd.ts +++ b/tfjs-core/src/ops/gather_nd.ts @@ -14,8 +14,10 @@ * limitations under the License. * ============================================================================= */ -import {ENGINE} from '../engine'; +import {ENGINE, ForwardFunc} from '../engine'; +import {GatherND, GatherNDInputs} from '../kernel_names'; import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {op} from './operation'; @@ -60,8 +62,15 @@ import {op} from './operation'; function gatherND_(x: Tensor|TensorLike, indices: Tensor|TensorLike): Tensor { const $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32'); const $x = convertToTensor(x, 'x', 'gatherND'); + + const forward: ForwardFunc = (backend) => { + return backend.gatherND($x, $indices); + }; + + const inputs: GatherNDInputs = {params: $x, indices: $indices}; + return ENGINE.runKernelFunc( - backend => backend.gatherND($x, $indices), {x: $x, indices: $indices}, - null /* backward */, 'GatherNd'); + forward, inputs as {} as NamedTensorMap, null /* gradient */, GatherND); } + export const gatherND = op({gatherND_}); From 51a06ed64dab0c0c58269330f024d3167b7babcf Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Thu, 4 Jun 2020 20:39:35 -0400 Subject: [PATCH 2/2] update gathernd in wasm. fix spelling --- tfjs-backend-wasm/src/kernels/GatherNd.ts | 20 ++++++++------------ tfjs-core/src/kernel_names.ts | 4 ++-- tfjs-core/src/ops/gather_nd.ts | 6 +++--- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/GatherNd.ts b/tfjs-backend-wasm/src/kernels/GatherNd.ts index d3fc9b13f9d..bb197289087 100644 --- a/tfjs-backend-wasm/src/kernels/GatherNd.ts +++ b/tfjs-backend-wasm/src/kernels/GatherNd.ts @@ -15,15 +15,11 @@ * ============================================================================= */ -import {gather_util, NamedTensorInfoMap, registerKernel, Tensor, TensorInfo} from '@tensorflow/tfjs-core'; +import {gather_util, GatherNd, GatherNdInputs, registerKernel, Tensor, TensorInfo} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -import {CppDType} from './types'; -interface GatherNdInputs extends NamedTensorInfoMap { - x: TensorInfo; - indices: TensorInfo; -} +import {CppDType} from './types'; let wasmGatherNd: ( xId: number, dtype: CppDType, indicesId: number, numSlices: number, @@ -46,12 +42,12 @@ function setup(backend: BackendWasm): void { function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}): TensorInfo { const {backend, inputs} = args; - const {x, indices} = inputs; + const {params, indices} = inputs; const [resultShape, numSlices, sliceSize, strides] = - gather_util.prepareAndValidate(x as Tensor, indices as Tensor); + gather_util.prepareAndValidate(params as Tensor, indices as Tensor); - const out = backend.makeOutput(resultShape, x.dtype); + const out = backend.makeOutput(resultShape, params.dtype); if (numSlices === 0) { return out; } @@ -59,7 +55,7 @@ function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}): const indicesShape = indices.shape; const sliceRank = indicesShape[indicesShape.length - 1]; - const xData = backend.dataIdMap.get(x.dataId); + const xData = backend.dataIdMap.get(params.dataId); const xId = xData.id; const indicesData = backend.dataIdMap.get(indices.dataId); const indicesId = indicesData.id; @@ -68,14 +64,14 @@ function gatherNd(args: {backend: BackendWasm, inputs: GatherNdInputs}): const outId = backend.dataIdMap.get(out.dataId).id; wasmGatherNd( - xId, CppDType[x.dtype], indicesId, numSlices, sliceRank, sliceSize, + xId, CppDType[params.dtype], indicesId, numSlices, sliceRank, sliceSize, stridesBytes, outId); return out; } registerKernel({ - kernelName: 'GatherNd', + kernelName: GatherNd, backendName: 'wasm', setupFunc: setup, kernelFunc: gatherNd diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 1d7f75929be..8106eedce52 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -197,8 +197,8 @@ export interface FusedBatchNormAttrs { varianceEpsilon: number; } -export const GatherND = 'GatherND'; -export type GatherNDInputs = Pick; +export const GatherNd = 'GatherNd'; +export type GatherNdInputs = Pick; export const Greater = 'Greater'; export type GreaterInputs = BinaryInputs; diff --git a/tfjs-core/src/ops/gather_nd.ts b/tfjs-core/src/ops/gather_nd.ts index 197702c95e7..d4e37404e56 100644 --- a/tfjs-core/src/ops/gather_nd.ts +++ b/tfjs-core/src/ops/gather_nd.ts @@ -15,7 +15,7 @@ * ============================================================================= */ import {ENGINE, ForwardFunc} from '../engine'; -import {GatherND, GatherNDInputs} from '../kernel_names'; +import {GatherNd, GatherNdInputs} from '../kernel_names'; import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; @@ -67,10 +67,10 @@ function gatherND_(x: Tensor|TensorLike, indices: Tensor|TensorLike): Tensor { return backend.gatherND($x, $indices); }; - const inputs: GatherNDInputs = {params: $x, indices: $indices}; + const inputs: GatherNdInputs = {params: $x, indices: $indices}; return ENGINE.runKernelFunc( - forward, inputs as {} as NamedTensorMap, null /* gradient */, GatherND); + forward, inputs as {} as NamedTensorMap, null /* gradient */, GatherNd); } export const gatherND = op({gatherND_});