diff --git a/tfjs-backend-wasm/src/kernels/Gather.ts b/tfjs-backend-wasm/src/kernels/GatherV2.ts similarity index 80% rename from tfjs-backend-wasm/src/kernels/Gather.ts rename to tfjs-backend-wasm/src/kernels/GatherV2.ts index 7ba83995701..80407e7034e 100644 --- a/tfjs-backend-wasm/src/kernels/Gather.ts +++ b/tfjs-backend-wasm/src/kernels/GatherV2.ts @@ -15,19 +15,11 @@ * ============================================================================= */ -import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelFunc, registerKernel, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -import {CppDType} from './types'; - -interface GatherInputs extends NamedTensorInfoMap { - x: TensorInfo; - indices: TensorInfo; -} -interface GatherAttrs extends NamedAttrMap { - axis: number; -} +import {CppDType} from './types'; let wasmGather: (xId: number, dtype: CppDType, xStrides: Uint8Array, stridesSize: number, @@ -47,8 +39,8 @@ function setup(backend: BackendWasm): void { ]); } -function gather( - args: {backend: BackendWasm, inputs: GatherInputs, attrs: GatherAttrs}): +function gatherV2( + args: {backend: BackendWasm, inputs: GatherV2Inputs, attrs: GatherV2Attrs}): TensorInfo { const {backend, inputs, attrs} = args; const {x, indices} = inputs; @@ -80,12 +72,18 @@ function gather( xId, CppDType[x.dtype], xStridesBytes, stridesSize, indicesId, axis, outStridesBytes, outId); + // reshape + const parsedAxis = util.parseAxisParam(axis, x.shape)[0]; + const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo( + x as Tensor, indices as Tensor, parsedAxis); + + out.shape = shapeInfo.outputShape; return out; } registerKernel({ - kernelName: 'Gather', + kernelName: GatherV2, backendName: 'wasm', setupFunc: setup, - kernelFunc: gather + kernelFunc: gatherV2 as {} as KernelFunc }); diff --git a/tfjs-backend-wasm/src/kernels/ScatterNd.ts b/tfjs-backend-wasm/src/kernels/ScatterNd.ts index dde34bf7c91..0beb5bb79ac 100644 --- a/tfjs-backend-wasm/src/kernels/ScatterNd.ts +++ b/tfjs-backend-wasm/src/kernels/ScatterNd.ts @@ -15,19 +15,11 @@ * ============================================================================= */ -import {NamedAttrMap, NamedTensorInfoMap, registerKernel, scatter_util, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {KernelFunc, registerKernel, scatter_util, ScatterNd, ScatterNdAttrs, ScatterNdInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -import {CppDType} from './types'; - -interface ScatterNdInputs extends NamedTensorInfoMap { - indices: TensorInfo; - updates: TensorInfo; -} -interface ScatterNdAttrs extends NamedAttrMap { - shape: number[]; -} +import {CppDType} from './types'; let wasmScatterNd: ( indicesId: number, updatesId: number, dtype: CppDType, sliceRank: number, @@ -81,8 +73,8 @@ function scatterNd( } registerKernel({ - kernelName: 'ScatterNd', + kernelName: ScatterNd, backendName: 'wasm', setupFunc: setup, - kernelFunc: scatterNd + kernelFunc: scatterNd as {} as KernelFunc }); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index c8cea76c002..3233938c308 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -41,7 +41,7 @@ import './FloorDiv'; import './FusedBatchNorm'; import './FusedConv2D'; import './FusedDepthwiseConv2D'; -import './Gather'; +import './GatherV2'; import './GatherNd'; import './Greater'; import './GreaterEqual'; diff --git a/tfjs-core/src/gradients/GatherV2_grad.ts b/tfjs-core/src/gradients/GatherV2_grad.ts new file mode 100644 index 00000000000..fea887393ed --- /dev/null +++ b/tfjs-core/src/gradients/GatherV2_grad.ts @@ -0,0 +1,85 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {GatherV2, GatherV2Attrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {getUndoAxesPermutation} from '../ops/axis_util'; +import {reshape} from '../ops/reshape'; +import {transpose} from '../ops/transpose'; +import {unsortedSegmentSum} from '../ops/unsorted_segment_sum'; +import {Tensor, Tensor1D} from '../tensor'; +import {parseAxisParam} from '../util'; + +export const gatherGradConfig: GradConfig = { + kernelName: GatherV2, + inputsToSave: ['x', 'indices'], + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const [x, indices] = saved; + const {axis} = attrs as {} as GatherV2Attrs; + + const parsedAxis = parseAxisParam(axis, x.shape)[0]; + + const derX = () => { + const paramsShape = x.shape; + const indicesSize = indices.size; + + const outerShape = paramsShape.slice(0, parsedAxis); + const outerDims = outerShape.length; + const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1); + const innerDims = innerShape.length; + + const outerAxesIndices = arrayRange(0, outerDims); + const innerAxesIndices = + arrayRange(outerDims + 1, outerDims + 1 + innerDims); + + const valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]); + + const values = reshape(dy, valuesShape); + const reshapedIndices = reshape(indices, [indicesSize]); + + const transposeDims = + arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]); + const valuesTranspose = transpose(values, transposeDims); + let paramsGrad = unsortedSegmentSum( + valuesTranspose, reshapedIndices as Tensor1D, x.shape[parsedAxis]); + + const invertTransposeDims = getUndoAxesPermutation(transposeDims); + paramsGrad = transpose(paramsGrad, invertTransposeDims); + + return paramsGrad; + }; + return {x: derX, indices: () => indices}; + } +}; + +function arrayRange(start: number, stop: number): number[] { + const result = []; + for (let i = start; i < stop; ++i) { + result.push(i); + } + return result; +} + +function arrayConcat(arrays: number[][]): number[] { + const result = []; + for (let i = 0; i < arrays.length; ++i) { + for (let j = 0; j < arrays[i].length; ++j) { + result.push(arrays[i][j]); + } + } + return result; +} diff --git a/tfjs-core/src/gradients/UnsortedSegmentSum_grad.ts b/tfjs-core/src/gradients/UnsortedSegmentSum_grad.ts new file mode 100644 index 00000000000..982194417a5 --- /dev/null +++ b/tfjs-core/src/gradients/UnsortedSegmentSum_grad.ts @@ -0,0 +1,56 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {UnsortedSegmentSum} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {expandDims} from '../ops/expand_dims'; +import {gather} from '../ops/gather'; +import {greaterEqual} from '../ops/greater_equal'; +import {logicalAnd} from '../ops/logical_and'; +import {maximum} from '../ops/maximum'; +import {ones, scalar, zerosLike} from '../ops/tensor_ops'; +import {where} from '../ops/where'; +import {Tensor, Tensor1D} from '../tensor'; + +export const unsortedSegmentSumGradConfig: GradConfig = { + kernelName: UnsortedSegmentSum, + inputsToSave: ['segmentIds'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [segmentIds] = saved; + + const derX = () => { + return gatherDropNegatives(dy, segmentIds as Tensor1D); + }; + return {x: derX}; + } +}; + +function gatherDropNegatives(x: T, indices: Tensor1D) { + // Helper function for unsorted segment ops. Gathers params for + // positive segment ids and gathers 0 for inputs with negative segment id. + // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py + const zeroClippedIndices = maximum(indices, zerosLike(indices)); + const gathered = gather(x, zeroClippedIndices as Tensor1D); + let isPositive = greaterEqual(indices, scalar(0, 'int32')); + const numIters = gathered.rank - isPositive.rank; + for (let i = 0; i < numIters; ++i) { + isPositive = expandDims(isPositive, i + 1); + } + isPositive = logicalAnd(isPositive, ones(gathered.shape, 'bool')); + const zeroSlice = zerosLike(gathered); + return where(isPositive, gathered, zeroSlice); +} diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index ac5d8e59300..5f4d8c2865c 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -273,6 +273,12 @@ export interface FusedBatchNormAttrs { varianceEpsilon: number; } +export const GatherV2 = 'GatherV2'; +export type GatherV2Inputs = Pick; +export interface GatherV2Attrs { + axis: number; +} + export const GatherNd = 'GatherNd'; export type GatherNdInputs = Pick; @@ -499,6 +505,12 @@ export interface ReverseAttrs { dims: number|number[]; } +export const ScatterNd = 'ScatterNd'; +export type ScatterNdInputs = Pick; +export interface ScatterNdAttrs { + shape: number[]; +} + export const SelectV2 = 'SelectV2'; export type SelectV2Inputs = Pick; @@ -553,6 +565,13 @@ export interface UnpackAttrs { axis: number; } +export const UnsortedSegmentSum = 'UnsortedSegmentSum'; +export type UnsortedSegmentSumInputs = + Pick; +export interface UnsortedSegmentSumAttrs { + numSegments: number; +} + /** * TensorFlow.js-only kernels */ diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index 0179f106e9b..c124102dd04 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -1443,496 +1443,6 @@ describeWithFlags('clone', ALL_ENVS, () => { }); }); -describeWithFlags('gather', ALL_ENVS, () => { - it('1D (gather), scalar indices', async () => { - const t = tf.tensor1d([1, 2, 3]); - - const t2 = tf.gather(t, tf.scalar(1, 'int32'), 0); - - expect(t2.shape).toEqual([]); - expectArraysClose(await t2.data(), [2]); - }); - - it('1D (gather), 1D indices', async () => { - const t = tf.tensor1d([1, 2, 3]); - - const t2 = tf.gather(t, tf.tensor1d([0, 2, 0, 1], 'int32'), 0); - - expect(t2.shape).toEqual([4]); - expectArraysClose(await t2.data(), [1, 3, 1, 2]); - }); - - it('1D (gather), 2D indices', async () => { - const t = tf.tensor1d([1, 2, 3]); - - const t2 = tf.gather(t, tf.tensor2d([0, 2, 0, 1], [1, 4], 'int32'), 0); - - expect(t2.shape).toEqual([1, 4]); - expectArraysClose(await t2.data(), [1, 3, 1, 2]); - }); - - it('2D (gather), scalar indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - let t2 = tf.gather(t, tf.scalar(1, 'int32'), 0); - expect(t2.shape).toEqual([2]); - expectArraysClose(await t2.data(), [2, 22]); - - t2 = tf.gather(t, tf.scalar(1, 'int32'), 1); - expect(t2.shape).toEqual([2]); - expectArraysClose(await t2.data(), [11, 22]); - }); - - it('2D (gather), 1D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - let t2 = tf.gather(t, tf.tensor1d([1, 0, 0, 1], 'int32'), 0); - expect(t2.shape).toEqual([4, 2]); - expectArraysClose(await t2.data(), [2, 22, 1, 11, 1, 11, 2, 22]); - - t2 = tf.gather(t, tf.tensor1d([1, 0, 0, 1], 'int32'), 1); - expect(t2.shape).toEqual([2, 4]); - expectArraysClose(await t2.data(), [11, 1, 1, 11, 22, 2, 2, 22]); - }); - - it('2D (gather), 2D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - let t2 = tf.gather(t, tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'), 0); - expect(t2.shape).toEqual([2, 2, 2]); - expectArraysClose(await t2.data(), [2, 22, 1, 11, 1, 11, 2, 22]); - - t2 = tf.gather(t, tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'), 1); - expect(t2.shape).toEqual([2, 2, 2]); - expectArraysClose(await t2.data(), [11, 1, 1, 11, 22, 2, 2, 22]); - }); - - it('3D (gather), 1D indices', async () => { - const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); - - const t2 = tf.gather(t, tf.tensor1d([1, 0, 0, 1], 'int32'), 2); - - expect(t2.shape).toEqual([2, 2, 4]); - expectArraysClose( - await t2.data(), [2, 1, 1, 2, 4, 3, 3, 4, 6, 5, 5, 6, 8, 7, 7, 8]); - }); - - it('3D (gather), 2D indices', async () => { - const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); - - const t2 = tf.gather(t, tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'), 2); - - expect(t2.shape).toEqual([2, 2, 2, 2]); - expectArraysClose( - await t2.data(), [2, 1, 1, 2, 4, 3, 3, 4, 6, 5, 5, 6, 8, 7, 7, 8]); - }); - - it('bool (gather), 1D indices', async () => { - const t = tf.tensor1d([true, false, true], 'bool'); - - const t2 = tf.gather(t, tf.tensor1d([0, 2, 0, 1], 'int32'), 0); - - expect(t2.shape).toEqual([4]); - expect(t2.dtype).toBe('bool'); - expect(await t2.data()).toEqual(new Uint8Array([1, 1, 1, 0])); - }); - - it('bool (gather), 2D indices', async () => { - const t = tf.tensor1d([true, false, true], 'bool'); - - const t2 = tf.gather(t, tf.tensor2d([0, 2, 0, 1], [2, 2], 'int32'), 0); - - expect(t2.shape).toEqual([2, 2]); - expect(t2.dtype).toBe('bool'); - expect(await t2.data()).toEqual(new Uint8Array([1, 1, 1, 0])); - }); - - it('int32 (gather), 1D indices', async () => { - const t = tf.tensor1d([1, 2, 5], 'int32'); - - const t2 = tf.gather(t, tf.tensor1d([0, 2, 0, 1], 'int32'), 0); - - expect(t2.shape).toEqual([4]); - expect(t2.dtype).toBe('int32'); - expect(await t2.data()).toEqual(new Int32Array([1, 5, 1, 2])); - }); - - it('int32 (gather), 2D indices', async () => { - const t = tf.tensor1d([1, 2, 5], 'int32'); - - const t2 = tf.gather(t, tf.tensor2d([0, 2, 0, 1], [2, 2], 'int32'), 0); - - expect(t2.shape).toEqual([2, 2]); - expect(t2.dtype).toBe('int32'); - expect(await t2.data()).toEqual(new Int32Array([1, 5, 1, 2])); - }); - - it('propagates NaNs', async () => { - const t = tf.tensor1d([1, 2, NaN]); - - const t2 = tf.gather(t, tf.tensor1d([0, 2, 0, 1], 'int32'), 0); - - expect(t2.shape).toEqual([4]); - expectArraysClose(await t2.data(), [1, NaN, 1, 2]); - }); - - it('chaining, axis=1', () => { - const x = tf.zeros([2, 4, 6]); - // [0, 2, 4] - const indices = tf.range(0, 6, 2, 'int32'); - const axis = 2; - expect(x.gather(indices, axis).shape).toEqual([2, 4, 3]); - }); - - it('indices not int32 throws error', () => { - const x = tf.zeros([2, 4, 6]); - // [0, 2, 4] - const indices = tf.range(0, 6, 2); - const axis = 2; - expect(() => x.gather(indices, axis)).toThrowError(); - }); - - it('throws when passed x as a non-tensor', () => { - expect(() => tf.gather({} as tf.Tensor, tf.tensor1d([1]))) - .toThrowError(/Argument 'x' passed to 'gather' must be a Tensor/); - }); - - it('throws when passed indices as a non-tensor', () => { - // tslint:disable-next-line:no-any - expect(() => tf.gather(tf.tensor1d([1]), {} as any)) - .toThrowError(/Argument 'indices' passed to 'gather' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const res = tf.gather([1, 2, 3], [0, 2, 0, 1], 0); - expect(res.shape).toEqual([4]); - expectArraysClose(await res.data(), [1, 3, 1, 2]); - }); - - it('gradient 1D (gather), 1D indices', async () => { - const t = tf.tensor1d([1, 2, 3]); - const indices = tf.tensor1d([0, 2, 0, 1], 'int32'); - const dy = tf.tensor([3, 4, 5, 6]); - - const gradients = tf.grad(t => tf.gather(t, indices))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [8, 6, 4]); - }); - - it('gradient with clones', () => { - const t = tf.tensor1d([1, 2, 3]); - const indices = tf.tensor1d([0, 2, 0, 1], 'int32'); - const gradF = tf.grad(t => tf.gather(t.clone(), indices.clone()).clone()); - const dt = gradF(t); - expect(dt.shape).toEqual(t.shape); - }); - - it('gradient 1D (gather), 2D indices', async () => { - const t = tf.tensor1d([1, 2, 3]); - const indices = tf.tensor2d([0, 2, 0, 1], [2, 2], 'int32'); - const dy = tf.tensor2d([3, 4, 5, 6], [2, 2]); - - const gradients = tf.grad(t => tf.gather(t, indices))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [8, 6, 4]); - }); - - it('gradient 2D (gather) axis=0 shape=[2, 2] 1D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - const indices = tf.tensor1d([1, 0, 0, 1], 'int32'); - const dy = tf.tensor([3, 4, 5, 6, 7, 8, 9, 10], [4, 2]); - const axis = 0; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [12, 14, 12, 14]); - }); - - it('gradient 2D (gather) axis=0 shape=[2, 2] 2D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - const indices = tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'); - const dy = tf.tensor([3, 4, 5, 6, 7, 8, 9, 10], [2, 2, 2]); - const axis = 0; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [12, 14, 12, 14]); - }); - - it('gradient 2D (gather) axis=0 shape=[4, 1] 1D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [4, 1]); - const indices = tf.tensor1d([1, 0, 0, 1], 'int32'); - const dy = tf.tensor([23, 7, 19, 13], [4, 1]); - const axis = 0; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [26, 36, 0, 0]); - }); - - it('gradient 2D (gather) axis=0 shape=[4, 1] 2D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [4, 1]); - const indices = tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'); - const dy = tf.tensor([23, 7, 19, 13], [2, 2, 1]); - const axis = 0; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [26, 36, 0, 0]); - }); - - it('gradient 2D (gather) axis=1 shape=[2, 2] 1D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - const indices = tf.tensor1d([1, 0, 0, 1], 'int32'); - const dy = tf.tensor([3, 4, 5, 6, 7, 8, 9, 10], [2, 4]); - const axis = 1; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [9, 9, 17, 17]); - }); - - it('gradient 2D (gather) axis=1 shape=[2, 2] 2D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - const indices = tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'); - const dy = tf.tensor([3, 4, 5, 6, 7, 8, 9, 10], [2, 2, 2]); - const axis = 1; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [9, 9, 17, 17]); - }); - - it('gradient 2D (gather) axis=1 shape=[4, 1] 1D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [4, 1]); - const indices = tf.tensor1d([0, 0, 0, 0], 'int32'); - const dy = tf.tensor( - [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [4, 4]); - const axis = 1; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [18, 34, 50, 66]); - }); - - it('gradient 2D (gather) axis=1 shape=[4, 1] 2D indices', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [4, 1]); - const indices = tf.tensor2d([0, 0, 0, 0], [2, 2], 'int32'); - const dy = tf.tensor( - [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [4, 2, 2]); - const axis = 1; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose(await gradients.data(), [18, 34, 50, 66]); - }); - - it('gradient 3D (gather) axis=0 shape=[2, 3, 2] 1D indices', async () => { - const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); - const indices = tf.tensor1d([1, 0, 0, 1], 'int32'); - const dy = tf.tensor( - [ - 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, - 4, 15, 12, -7, 18, 19, 2, 21, 6, 23, 24, 25 - ], - [4, 3, 2]); - const axis = 0; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [5, 33, 12.01, -7, 30, 32, 4, 18, 10, 38, 30, 25.7]); - }); - - it('gradient 3D (gather) axis=0 shape=[2, 3, 2] 2D indices', async () => { - const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); - const indices = tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'); - const dy = tf.tensor( - [ - 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, - 4, 15, 12, -7, 18, 19, 2, 21, 6, 23, 24, 25 - ], - [2, 2, 3, 2]); - const axis = 0; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [5, 33, 12.01, -7, 30, 32, 4, 18, 10, 38, 30, 25.7]); - }); - - it('gradient 3D (gather) axis=0 shape=[1, 4, 4]', async () => { - const t = tf.tensor3d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4]); - const indices = tf.tensor1d([0, 0], 'int32'); - const dy = tf.tensor( - [ - 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7, - 18, 19, 2, 21, 6, 23, 24, 25, 101, 31, 34, 54, 1, 0, -3, -4 - ], - [2, 4, 4]); - const axis = 0; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [20, 16, 6, 36, 12, 23.7, 25, 43, 101.01, 31, 46, 67, 5, 15, 9, -11]); - }); - - it('gradient 3D (gather) axis=0 shape=[1, 4, 4] 1D indices', async () => { - const t = tf.tensor3d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4]); - const indices = tf.tensor1d([0, 0], 'int32'); - const dy = tf.tensor( - [ - 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7, - 18, 19, 2, 21, 6, 23, 24, 25, 101, 31, 34, 54, 1, 0, -3, -4 - ], - [2, 4, 4]); - const axis = 0; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [20, 16, 6, 36, 12, 23.7, 25, 43, 101.01, 31, 46, 67, 5, 15, 9, -11]); - }); - - it('gradient 3D (gather) axis=1 shape=[2, 3, 2] 2D indices', async () => { - const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); - const indices = tf.tensor2d([1, 2, 2, 1], [2, 2], 'int32'); - const dy = tf.tensor( - [2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7], - [2, 2, 2, 2]); - const axis = 1; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [0, 0, 3, 15, 10, 15.7, 0, 0, 12.01, -7, 16, 28]); - }); - - it('gradient 3D (gather) axis=1 shape=[1, 4, 4] 1D indices', async () => { - const t = tf.tensor3d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4]); - const indices = tf.tensor1d([1, 2, 2, 1], 'int32'); - const dy = tf.tensor( - [2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7], - [1, 4, 4]); - const axis = 1; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [0, 0, 0, 0, 6, 12, 16, 8, 6.01, .7, 13, 31, 0, 0, 0, 0]); - }); - - it('gradient 3D (gather) axis=1 shape=[1, 4, 4] 2D indices', async () => { - const t = tf.tensor3d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4]); - const indices = tf.tensor2d([1, 2, 2, 1], [2, 2], 'int32'); - const dy = tf.tensor( - [2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7], - [1, 2, 2, 4]); - const axis = 1; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [0, 0, 0, 0, 6, 12, 16, 8, 6.01, .7, 13, 31, 0, 0, 0, 0]); - }); - - it('gradient 3D (gather) axis=2 shape=[2, 3, 2] 1D indices', async () => { - const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); - const indices = tf.tensor1d([1, 0, 1, 0], 'int32'); - const dy = tf.tensor( - [ - 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, - 4, 15, 12, -7, 18, 19, 2, 21, 6, 23, 24, 25 - ], - [2, 3, 4]); - const axis = 2; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [12, 6, 18.7, 7, 13, 12.01, 8, 16, 40, 20, 48, 30]); - }); - - it('gradient 3D (gather) axis=2 shape=[2, 3, 2] 2D indices', async () => { - const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); - const indices = tf.tensor2d([1, 0, 1, 0], [2, 2], 'int32'); - const dy = tf.tensor( - [ - 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, - 4, 15, 12, -7, 18, 19, 2, 21, 6, 23, 24, 25 - ], - [2, 3, 2, 2]); - const axis = 2; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [12, 6, 18.7, 7, 13, 12.01, 8, 16, 40, 20, 48, 30]); - }); - - it('gradient 3D (gather) axis=2 shape=[4, 1, 4] 1D indices', async () => { - const t = tf.tensor3d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [4, 1, 4]); - const indices = tf.tensor1d([1, 3, 1], 'int32'); - const dy = - tf.tensor([2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 4, 15], [4, 1, 3]); - const axis = 2; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [0, 6, 0, -3, 0, 15.7, 0, 6, 0, 1.01, 0, 18, 0, 15, 0, 4]); - }); - - it('gradient 3D (gather) axis=2 shape=[4, 1, 4] 2D indices', async () => { - const t = tf.tensor3d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [4, 1, 4]); - const indices = tf.tensor2d([1, 3, 1], [1, 3], 'int32'); - const dy = - tf.tensor([2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 4, 15], [4, 1, 1, 3]); - const axis = 2; - - const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); - - expect(gradients.shape).toEqual(t.shape); - expectArraysClose( - await gradients.data(), - [0, 6, 0, -3, 0, 15.7, 0, 6, 0, 1.01, 0, 18, 0, 15, 0, 4]); - }); -}); - describeWithFlags('linspace', ALL_ENVS, () => { it('start stop', async () => { const a = tf.linspace(1, 10, 10); diff --git a/tfjs-core/src/ops/boolean_mask.ts b/tfjs-core/src/ops/boolean_mask.ts index cf455d6146f..03ad2550d07 100644 --- a/tfjs-core/src/ops/boolean_mask.ts +++ b/tfjs-core/src/ops/boolean_mask.ts @@ -20,7 +20,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {gather} from './segment_ops'; +import {gather} from './gather'; import {whereAsync} from './where_async'; /** diff --git a/tfjs-core/src/ops/gather.ts b/tfjs-core/src/ops/gather.ts new file mode 100644 index 00000000000..d990a3969c6 --- /dev/null +++ b/tfjs-core/src/ops/gather.ts @@ -0,0 +1,74 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {GatherV2, GatherV2Attrs, GatherV2Inputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import {parseAxisParam} from '../util'; + +import {op} from './operation'; +import {collectGatherOpShapeInfo} from './segment_util'; + +/** + * Gather slices from tensor `x`'s axis `axis` according to `indices`. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * const indices = tf.tensor1d([1, 3, 3], 'int32'); + * + * x.gather(indices).print(); + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * const indices = tf.tensor1d([1, 1, 0], 'int32'); + * + * x.gather(indices).print(); + * ``` + * @param x The input tensor whose slices to be gathered. + * @param indices The indices of the values to extract. + * @param axis The axis over which to select values. Defaults to 0. + */ +/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ +function gather_( + x: T|TensorLike, indices: Tensor|TensorLike, axis = 0): T { + const $x = convertToTensor(x, 'x', 'gather'); + const $indices = convertToTensor(indices, 'indices', 'gather', 'int32'); + + const inputs: GatherV2Inputs = {x: $x, indices: $indices}; + const attrs: GatherV2Attrs = {axis}; + + const forward: ForwardFunc = (backend, save) => { + const parsedAxis = parseAxisParam(axis, $x.shape)[0]; + const shapeInfo = collectGatherOpShapeInfo($x, $indices, parsedAxis); + + const res = backend.gather($x, $indices.flatten(), parsedAxis); + save([$x, $indices]); + + return res.reshape(shapeInfo.outputShape); + }; + + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, GatherV2, + attrs as {} as NamedAttrMap) as T; +} + +export const gather = op({gather_}); diff --git a/tfjs-core/src/ops/gather_test.ts b/tfjs-core/src/ops/gather_test.ts new file mode 100644 index 00000000000..f6506c7a425 --- /dev/null +++ b/tfjs-core/src/ops/gather_test.ts @@ -0,0 +1,510 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('gather', ALL_ENVS, () => { + it('1D (gather), scalar indices', async () => { + const t = tf.tensor1d([1, 2, 3]); + + const t2 = tf.gather(t, tf.scalar(1, 'int32'), 0); + + expect(t2.shape).toEqual([]); + expectArraysClose(await t2.data(), [2]); + }); + + it('1D (gather), 1D indices', async () => { + const t = tf.tensor1d([1, 2, 3]); + + const t2 = tf.gather(t, tf.tensor1d([0, 2, 0, 1], 'int32'), 0); + + expect(t2.shape).toEqual([4]); + expectArraysClose(await t2.data(), [1, 3, 1, 2]); + }); + + it('1D (gather), 2D indices', async () => { + const t = tf.tensor1d([1, 2, 3]); + + const t2 = tf.gather(t, tf.tensor2d([0, 2, 0, 1], [1, 4], 'int32'), 0); + + expect(t2.shape).toEqual([1, 4]); + expectArraysClose(await t2.data(), [1, 3, 1, 2]); + }); + + it('2D (gather), scalar indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + let t2 = tf.gather(t, tf.scalar(1, 'int32'), 0); + expect(t2.shape).toEqual([2]); + expectArraysClose(await t2.data(), [2, 22]); + + t2 = tf.gather(t, tf.scalar(1, 'int32'), 1); + expect(t2.shape).toEqual([2]); + expectArraysClose(await t2.data(), [11, 22]); + }); + + it('2D (gather), 1D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + let t2 = tf.gather(t, tf.tensor1d([1, 0, 0, 1], 'int32'), 0); + expect(t2.shape).toEqual([4, 2]); + expectArraysClose(await t2.data(), [2, 22, 1, 11, 1, 11, 2, 22]); + + t2 = tf.gather(t, tf.tensor1d([1, 0, 0, 1], 'int32'), 1); + expect(t2.shape).toEqual([2, 4]); + expectArraysClose(await t2.data(), [11, 1, 1, 11, 22, 2, 2, 22]); + }); + + it('2D (gather), 2D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + let t2 = tf.gather(t, tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'), 0); + expect(t2.shape).toEqual([2, 2, 2]); + expectArraysClose(await t2.data(), [2, 22, 1, 11, 1, 11, 2, 22]); + + t2 = tf.gather(t, tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'), 1); + expect(t2.shape).toEqual([2, 2, 2]); + expectArraysClose(await t2.data(), [11, 1, 1, 11, 22, 2, 2, 22]); + }); + + it('3D (gather), 1D indices', async () => { + const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); + + const t2 = tf.gather(t, tf.tensor1d([1, 0, 0, 1], 'int32'), 2); + + expect(t2.shape).toEqual([2, 2, 4]); + expectArraysClose( + await t2.data(), [2, 1, 1, 2, 4, 3, 3, 4, 6, 5, 5, 6, 8, 7, 7, 8]); + }); + + it('3D (gather), 2D indices', async () => { + const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); + + const t2 = tf.gather(t, tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'), 2); + + expect(t2.shape).toEqual([2, 2, 2, 2]); + expectArraysClose( + await t2.data(), [2, 1, 1, 2, 4, 3, 3, 4, 6, 5, 5, 6, 8, 7, 7, 8]); + }); + + it('bool (gather), 1D indices', async () => { + const t = tf.tensor1d([true, false, true], 'bool'); + + const t2 = tf.gather(t, tf.tensor1d([0, 2, 0, 1], 'int32'), 0); + + expect(t2.shape).toEqual([4]); + expect(t2.dtype).toBe('bool'); + expect(await t2.data()).toEqual(new Uint8Array([1, 1, 1, 0])); + }); + + it('bool (gather), 2D indices', async () => { + const t = tf.tensor1d([true, false, true], 'bool'); + + const t2 = tf.gather(t, tf.tensor2d([0, 2, 0, 1], [2, 2], 'int32'), 0); + + expect(t2.shape).toEqual([2, 2]); + expect(t2.dtype).toBe('bool'); + expect(await t2.data()).toEqual(new Uint8Array([1, 1, 1, 0])); + }); + + it('int32 (gather), 1D indices', async () => { + const t = tf.tensor1d([1, 2, 5], 'int32'); + + const t2 = tf.gather(t, tf.tensor1d([0, 2, 0, 1], 'int32'), 0); + + expect(t2.shape).toEqual([4]); + expect(t2.dtype).toBe('int32'); + expect(await t2.data()).toEqual(new Int32Array([1, 5, 1, 2])); + }); + + it('int32 (gather), 2D indices', async () => { + const t = tf.tensor1d([1, 2, 5], 'int32'); + + const t2 = tf.gather(t, tf.tensor2d([0, 2, 0, 1], [2, 2], 'int32'), 0); + + expect(t2.shape).toEqual([2, 2]); + expect(t2.dtype).toBe('int32'); + expect(await t2.data()).toEqual(new Int32Array([1, 5, 1, 2])); + }); + + it('propagates NaNs', async () => { + const t = tf.tensor1d([1, 2, NaN]); + + const t2 = tf.gather(t, tf.tensor1d([0, 2, 0, 1], 'int32'), 0); + + expect(t2.shape).toEqual([4]); + expectArraysClose(await t2.data(), [1, NaN, 1, 2]); + }); + + it('chaining, axis=1', () => { + const x = tf.zeros([2, 4, 6]); + // [0, 2, 4] + const indices = tf.range(0, 6, 2, 'int32'); + const axis = 2; + expect(x.gather(indices, axis).shape).toEqual([2, 4, 3]); + }); + + it('indices not int32 throws error', () => { + const x = tf.zeros([2, 4, 6]); + // [0, 2, 4] + const indices = tf.range(0, 6, 2); + const axis = 2; + expect(() => x.gather(indices, axis)).toThrowError(); + }); + + it('throws when passed x as a non-tensor', () => { + expect(() => tf.gather({} as tf.Tensor, tf.tensor1d([1]))) + .toThrowError(/Argument 'x' passed to 'gather' must be a Tensor/); + }); + + it('throws when passed indices as a non-tensor', () => { + // tslint:disable-next-line:no-any + expect(() => tf.gather(tf.tensor1d([1]), {} as any)) + .toThrowError(/Argument 'indices' passed to 'gather' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const res = tf.gather([1, 2, 3], [0, 2, 0, 1], 0); + expect(res.shape).toEqual([4]); + expectArraysClose(await res.data(), [1, 3, 1, 2]); + }); + + it('gradient 1D (gather), 1D indices', async () => { + const t = tf.tensor1d([1, 2, 3]); + const indices = tf.tensor1d([0, 2, 0, 1], 'int32'); + const dy = tf.tensor([3, 4, 5, 6]); + + const gradients = tf.grad(t => tf.gather(t, indices))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [8, 6, 4]); + }); + + it('gradient with clones', () => { + const t = tf.tensor1d([1, 2, 3]); + const indices = tf.tensor1d([0, 2, 0, 1], 'int32'); + const gradF = tf.grad(t => tf.gather(t.clone(), indices.clone()).clone()); + const dt = gradF(t); + expect(dt.shape).toEqual(t.shape); + }); + + it('gradient 1D (gather), 2D indices', async () => { + const t = tf.tensor1d([1, 2, 3]); + const indices = tf.tensor2d([0, 2, 0, 1], [2, 2], 'int32'); + const dy = tf.tensor2d([3, 4, 5, 6], [2, 2]); + + const gradients = tf.grad(t => tf.gather(t, indices))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [8, 6, 4]); + }); + + it('gradient 2D (gather) axis=0 shape=[2, 2] 1D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + const indices = tf.tensor1d([1, 0, 0, 1], 'int32'); + const dy = tf.tensor([3, 4, 5, 6, 7, 8, 9, 10], [4, 2]); + const axis = 0; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [12, 14, 12, 14]); + }); + + it('gradient 2D (gather) axis=0 shape=[2, 2] 2D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + const indices = tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'); + const dy = tf.tensor([3, 4, 5, 6, 7, 8, 9, 10], [2, 2, 2]); + const axis = 0; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [12, 14, 12, 14]); + }); + + it('gradient 2D (gather) axis=0 shape=[4, 1] 1D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [4, 1]); + const indices = tf.tensor1d([1, 0, 0, 1], 'int32'); + const dy = tf.tensor([23, 7, 19, 13], [4, 1]); + const axis = 0; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [26, 36, 0, 0]); + }); + + it('gradient 2D (gather) axis=0 shape=[4, 1] 2D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [4, 1]); + const indices = tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'); + const dy = tf.tensor([23, 7, 19, 13], [2, 2, 1]); + const axis = 0; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [26, 36, 0, 0]); + }); + + it('gradient 2D (gather) axis=1 shape=[2, 2] 1D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + const indices = tf.tensor1d([1, 0, 0, 1], 'int32'); + const dy = tf.tensor([3, 4, 5, 6, 7, 8, 9, 10], [2, 4]); + const axis = 1; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [9, 9, 17, 17]); + }); + + it('gradient 2D (gather) axis=1 shape=[2, 2] 2D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + const indices = tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'); + const dy = tf.tensor([3, 4, 5, 6, 7, 8, 9, 10], [2, 2, 2]); + const axis = 1; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [9, 9, 17, 17]); + }); + + it('gradient 2D (gather) axis=1 shape=[4, 1] 1D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [4, 1]); + const indices = tf.tensor1d([0, 0, 0, 0], 'int32'); + const dy = tf.tensor( + [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [4, 4]); + const axis = 1; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [18, 34, 50, 66]); + }); + + it('gradient 2D (gather) axis=1 shape=[4, 1] 2D indices', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [4, 1]); + const indices = tf.tensor2d([0, 0, 0, 0], [2, 2], 'int32'); + const dy = tf.tensor( + [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [4, 2, 2]); + const axis = 1; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose(await gradients.data(), [18, 34, 50, 66]); + }); + + it('gradient 3D (gather) axis=0 shape=[2, 3, 2] 1D indices', async () => { + const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); + const indices = tf.tensor1d([1, 0, 0, 1], 'int32'); + const dy = tf.tensor( + [ + 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, + 4, 15, 12, -7, 18, 19, 2, 21, 6, 23, 24, 25 + ], + [4, 3, 2]); + const axis = 0; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [5, 33, 12.01, -7, 30, 32, 4, 18, 10, 38, 30, 25.7]); + }); + + it('gradient 3D (gather) axis=0 shape=[2, 3, 2] 2D indices', async () => { + const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); + const indices = tf.tensor2d([1, 0, 0, 1], [2, 2], 'int32'); + const dy = tf.tensor( + [ + 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, + 4, 15, 12, -7, 18, 19, 2, 21, 6, 23, 24, 25 + ], + [2, 2, 3, 2]); + const axis = 0; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [5, 33, 12.01, -7, 30, 32, 4, 18, 10, 38, 30, 25.7]); + }); + + it('gradient 3D (gather) axis=0 shape=[1, 4, 4]', async () => { + const t = tf.tensor3d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4]); + const indices = tf.tensor1d([0, 0], 'int32'); + const dy = tf.tensor( + [ + 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7, + 18, 19, 2, 21, 6, 23, 24, 25, 101, 31, 34, 54, 1, 0, -3, -4 + ], + [2, 4, 4]); + const axis = 0; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [20, 16, 6, 36, 12, 23.7, 25, 43, 101.01, 31, 46, 67, 5, 15, 9, -11]); + }); + + it('gradient 3D (gather) axis=0 shape=[1, 4, 4] 1D indices', async () => { + const t = tf.tensor3d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4]); + const indices = tf.tensor1d([0, 0], 'int32'); + const dy = tf.tensor( + [ + 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7, + 18, 19, 2, 21, 6, 23, 24, 25, 101, 31, 34, 54, 1, 0, -3, -4 + ], + [2, 4, 4]); + const axis = 0; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [20, 16, 6, 36, 12, 23.7, 25, 43, 101.01, 31, 46, 67, 5, 15, 9, -11]); + }); + + it('gradient 3D (gather) axis=1 shape=[2, 3, 2] 2D indices', async () => { + const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); + const indices = tf.tensor2d([1, 2, 2, 1], [2, 2], 'int32'); + const dy = tf.tensor( + [2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7], + [2, 2, 2, 2]); + const axis = 1; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [0, 0, 3, 15, 10, 15.7, 0, 0, 12.01, -7, 16, 28]); + }); + + it('gradient 3D (gather) axis=1 shape=[1, 4, 4] 1D indices', async () => { + const t = tf.tensor3d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4]); + const indices = tf.tensor1d([1, 2, 2, 1], 'int32'); + const dy = tf.tensor( + [2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7], + [1, 4, 4]); + const axis = 1; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [0, 0, 0, 0, 6, 12, 16, 8, 6.01, .7, 13, 31, 0, 0, 0, 0]); + }); + + it('gradient 3D (gather) axis=1 shape=[1, 4, 4] 2D indices', async () => { + const t = tf.tensor3d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4]); + const indices = tf.tensor2d([1, 2, 2, 1], [2, 2], 'int32'); + const dy = tf.tensor( + [2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, 4, 15, 12, -7], + [1, 2, 2, 4]); + const axis = 1; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [0, 0, 0, 0, 6, 12, 16, 8, 6.01, .7, 13, 31, 0, 0, 0, 0]); + }); + + it('gradient 3D (gather) axis=2 shape=[2, 3, 2] 1D indices', async () => { + const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); + const indices = tf.tensor1d([1, 0, 1, 0], 'int32'); + const dy = tf.tensor( + [ + 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, + 4, 15, 12, -7, 18, 19, 2, 21, 6, 23, 24, 25 + ], + [2, 3, 4]); + const axis = 2; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [12, 6, 18.7, 7, 13, 12.01, 8, 16, 40, 20, 48, 30]); + }); + + it('gradient 3D (gather) axis=2 shape=[2, 3, 2] 2D indices', async () => { + const t = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 3, 2]); + const indices = tf.tensor2d([1, 0, 1, 0], [2, 2], 'int32'); + const dy = tf.tensor( + [ + 2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 12, 13, + 4, 15, 12, -7, 18, 19, 2, 21, 6, 23, 24, 25 + ], + [2, 3, 2, 2]); + const axis = 2; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [12, 6, 18.7, 7, 13, 12.01, 8, 16, 40, 20, 48, 30]); + }); + + it('gradient 3D (gather) axis=2 shape=[4, 1, 4] 1D indices', async () => { + const t = tf.tensor3d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [4, 1, 4]); + const indices = tf.tensor1d([1, 3, 1], 'int32'); + const dy = + tf.tensor([2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 4, 15], [4, 1, 3]); + const axis = 2; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [0, 6, 0, -3, 0, 15.7, 0, 6, 0, 1.01, 0, 18, 0, 15, 0, 4]); + }); + + it('gradient 3D (gather) axis=2 shape=[4, 1, 4] 2D indices', async () => { + const t = tf.tensor3d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [4, 1, 4]); + const indices = tf.tensor2d([1, 3, 1], [1, 3], 'int32'); + const dy = + tf.tensor([2, -3, 4, 15, 6, 0.7, 1, 18, 0.01, 0, 4, 15], [4, 1, 1, 3]); + const axis = 2; + + const gradients = tf.grad(t => tf.gather(t, indices, axis))(t, dy); + + expect(gradients.shape).toEqual(t.shape); + expectArraysClose( + await gradients.data(), + [0, 6, 0, -3, 0, 15.7, 0, 6, 0, 1.01, 0, 18, 0, 15, 0, 4]); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 5cce7903ad0..b7e031ba2b8 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -58,6 +58,7 @@ export {expandDims} from './expand_dims'; export {eye} from './eye'; export {fill} from './fill'; export {floorDiv} from './floorDiv'; +export {gather} from './gather'; export {greater} from './greater'; export {greaterEqual} from './greater_equal'; export {imag} from './imag'; @@ -121,6 +122,7 @@ export {sub} from './sub'; export {sum} from './sum'; export {tile} from './tile'; export {truncatedNormal} from './truncated_normal'; +export {unsortedSegmentSum} from './unsorted_segment_sum'; export {unstack} from './unstack'; export {where} from './where'; export {whereAsync} from './where_async'; @@ -135,7 +137,6 @@ export * from './tensor_ops'; export * from './transpose'; export * from './softmax'; export * from './norm'; -export * from './segment_ops'; export * from './moving_average'; export * from './strided_slice'; export * from './topk'; diff --git a/tfjs-core/src/ops/scatter_nd.ts b/tfjs-core/src/ops/scatter_nd.ts index 8c0bea19e6e..cbe779731c9 100644 --- a/tfjs-core/src/ops/scatter_nd.ts +++ b/tfjs-core/src/ops/scatter_nd.ts @@ -15,10 +15,14 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; +import {ENGINE, ForwardFunc} from '../engine'; +import {ScatterNd, ScatterNdAttrs, ScatterNdInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {Rank, ShapeMap, TensorLike} from '../types'; + import {op} from './operation'; import * as scatter_nd_util from './scatter_nd_util'; @@ -47,10 +51,16 @@ function scatterND_( const $updates = convertToTensor(updates, 'updates', 'scatterND'); scatter_nd_util.validateInput($updates, $indices, shape); + const forward: ForwardFunc = (backend) => { + return backend.scatterND($indices, $updates, shape); + }; + + const inputs: ScatterNdInputs = {indices: $indices, updates: $updates}; + const attrs: ScatterNdAttrs = {shape}; + return ENGINE.runKernelFunc( - backend => backend.scatterND($indices, $updates, shape), - {indices: $indices, updates: $updates}, null /* backward */, 'ScatterNd', - {shape}); + forward, inputs as {} as NamedTensorMap, null /* grad */, + ScatterNd, attrs as {} as NamedAttrMap) as Tensor; } export const scatterND = op({scatterND_}); diff --git a/tfjs-core/src/ops/segment_ops.ts b/tfjs-core/src/ops/segment_ops.ts deleted file mode 100644 index 357579ae6bb..00000000000 --- a/tfjs-core/src/ops/segment_ops.ts +++ /dev/null @@ -1,178 +0,0 @@ -/** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {ENGINE} from '../engine'; -import {Tensor, Tensor1D} from '../tensor'; -import {convertToTensor} from '../tensor_util_env'; -import {TensorLike} from '../types'; -import {assert, isInt, parseAxisParam} from '../util'; - -import {getUndoAxesPermutation} from './axis_util'; -import {expandDims} from './expand_dims'; -import {greaterEqual} from './greater_equal'; -import {logicalAnd} from './logical_and'; -import {maximum} from './maximum'; -import {op} from './operation'; -import {collectGatherOpShapeInfo} from './segment_util'; -import {ones, scalar, zerosLike} from './tensor_ops'; -import {where} from './where'; - -/** - * Computes the sum along segments of a `tf.Tensor`. - * - * ```js - * const x = tf.tensor1d([1, 2, 3, 4]); - * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32'); - * const numSegments = 3; - * - * x.unsortedSegmentSum(segmentIds, numSegments).print() - * //or tf.unsortedSegmentSum(x, segmentIds, numSegments) - * ``` - * @param x The `tf.Tensor` that will be summed along its segments. - * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s - * dimension along the `axis`. Maps each element of `x` to a segment. - * @param numSegments The number of distinct `segmentIds`. - */ -/** @doc {heading: 'Operations', subheading: 'Segment'} */ -function unsortedSegmentSum_( - x: T|TensorLike, segmentIds: Tensor1D|TensorLike, numSegments: number): T { - const $x = convertToTensor(x, 'x', 'unsortedSegmentSum'); - const $segmentIds = - convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32'); - assert(isInt(numSegments), () => 'numSegments must be of dtype int'); - - const gradFunc = (dy: T, saved: Tensor[]) => { - const [$segmentIds] = saved; - const derX = () => { - return gatherDropNegatives(dy, $segmentIds as Tensor1D); - }; - return {$x: derX}; - }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.unsortedSegmentSum($x, $segmentIds, numSegments); - save([$segmentIds]); - return res; - }, {$x}, gradFunc) as T; -} - -/** - * Gather slices from tensor `x`'s axis `axis` according to `indices`. - * - * ```js - * const x = tf.tensor1d([1, 2, 3, 4]); - * const indices = tf.tensor1d([1, 3, 3], 'int32'); - * - * x.gather(indices).print(); - * ``` - * - * ```js - * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * const indices = tf.tensor1d([1, 1, 0], 'int32'); - * - * x.gather(indices).print(); - * ``` - * @param x The input tensor whose slices to be gathered. - * @param indices The indices of the values to extract. - * @param axis The axis over which to select values. Defaults to 0. - */ -/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ -function gather_( - x: T|TensorLike, indices: Tensor|TensorLike, axis = 0): T { - const $x = convertToTensor(x, 'x', 'gather'); - const $indices = convertToTensor(indices, 'indices', 'gather', 'int32'); - axis = parseAxisParam(axis, $x.shape)[0]; - const shapeInfo = collectGatherOpShapeInfo($x, $indices, axis); - - const grad = (dy: T, saved: Tensor[]) => { - const [$indices] = saved; - const derX = () => { - const paramsShape = $x.shape; - const indicesSize = $indices.size; - - const outerShape = paramsShape.slice(0, axis); - const outerDims = outerShape.length; - const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1); - const innerDims = innerShape.length; - - const outerAxesIndices = arrayRange(0, outerDims); - const innerAxesIndices = - arrayRange(outerDims + 1, outerDims + 1 + innerDims); - - const valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]); - - const values = dy.reshape(valuesShape); - const reshapedIndices = $indices.reshape([indicesSize]); - - const transposeDims = - arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]); - const valuesTranspose = values.transpose(transposeDims); - let paramsGrad = unsortedSegmentSum( - valuesTranspose, reshapedIndices as Tensor1D, $x.shape[axis]); - - const invertTransposeDims = getUndoAxesPermutation(transposeDims); - paramsGrad = paramsGrad.transpose(invertTransposeDims); - - return paramsGrad as T; - }; - return {x: derX, indices: () => $indices}; - }; - return (ENGINE.runKernelFunc( - (backend, save) => { - const res = backend.gather($x, $indices.flatten(), axis); - save([$indices]); - return res; - }, - {x: $x, indices: $indices}, grad, 'Gather', {axis})) - .reshape(shapeInfo.outputShape); -} - -function arrayRange(start: number, stop: number): number[] { - const result = []; - for (let i = start; i < stop; ++i) { - result.push(i); - } - return result; -} - -function arrayConcat(arrays: number[][]): number[] { - const result = []; - for (let i = 0; i < arrays.length; ++i) { - for (let j = 0; j < arrays[i].length; ++j) { - result.push(arrays[i][j]); - } - } - return result; -} - -function gatherDropNegatives(x: T, indices: Tensor1D) { - // Helper function for unsorted segment ops. Gathers params for - // positive segment ids and gathers 0 for inputs with negative segment id. - // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py - const zeroClippedIndices = maximum(indices, zerosLike(indices)); - const gathered = gather(x, zeroClippedIndices as Tensor1D); - let isPositive = greaterEqual(indices, scalar(0, 'int32')); - const numIters = gathered.rank - isPositive.rank; - for (let i = 0; i < numIters; ++i) { - isPositive = expandDims(isPositive, i + 1); - } - isPositive = logicalAnd(isPositive, ones(gathered.shape, 'bool')); - const zeroSlice = zerosLike(gathered); - return where(isPositive, gathered, zeroSlice); -} - -export const gather = op({gather_}); -export const unsortedSegmentSum = op({unsortedSegmentSum_}); diff --git a/tfjs-core/src/ops/unsorted_segment_sum.ts b/tfjs-core/src/ops/unsorted_segment_sum.ts new file mode 100644 index 00000000000..50b5b8da4f1 --- /dev/null +++ b/tfjs-core/src/ops/unsorted_segment_sum.ts @@ -0,0 +1,67 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {UnsortedSegmentSum, UnsortedSegmentSumAttrs, UnsortedSegmentSumInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor1D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import {assert, isInt} from '../util'; + +import {op} from './operation'; + +/** + * Computes the sum along segments of a `tf.Tensor`. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32'); + * const numSegments = 3; + * + * x.unsortedSegmentSum(segmentIds, numSegments).print() + * //or tf.unsortedSegmentSum(x, segmentIds, numSegments) + * ``` + * @param x The `tf.Tensor` that will be summed along its segments. + * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s + * dimension along the `axis`. Maps each element of `x` to a segment. + * @param numSegments The number of distinct `segmentIds`. + */ +/** @doc {heading: 'Operations', subheading: 'Segment'} */ +function unsortedSegmentSum_( + x: T|TensorLike, segmentIds: Tensor1D|TensorLike, numSegments: number): T { + const $x = convertToTensor(x, 'x', 'unsortedSegmentSum'); + const $segmentIds = + convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32'); + assert(isInt(numSegments), () => 'numSegments must be of dtype int'); + + const inputs: UnsortedSegmentSumInputs = {x: $x, segmentIds: $segmentIds}; + const attrs: UnsortedSegmentSumAttrs = {numSegments}; + + const forward: ForwardFunc = (backend, save) => { + const res = backend.unsortedSegmentSum($x, $segmentIds, numSegments); + save([$segmentIds]); + return res; + }; + + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, + UnsortedSegmentSum, attrs as {} as NamedAttrMap) as T; +} + +export const unsortedSegmentSum = op({unsortedSegmentSum_}); diff --git a/tfjs-core/src/ops/segment_ops_test.ts b/tfjs-core/src/ops/unsorted_segment_sum_test.ts similarity index 100% rename from tfjs-core/src/ops/segment_ops_test.ts rename to tfjs-core/src/ops/unsorted_segment_sum_test.ts diff --git a/tfjs-core/src/public/chained_ops/gather.ts b/tfjs-core/src/public/chained_ops/gather.ts index dc444d252f8..8f9ffeaec73 100644 --- a/tfjs-core/src/public/chained_ops/gather.ts +++ b/tfjs-core/src/public/chained_ops/gather.ts @@ -15,8 +15,7 @@ * ============================================================================= */ -// TODO update import path once op is modularized. -import {gather} from '../../ops/ops'; +import {gather} from '../../ops/gather'; import {Tensor} from '../../tensor'; import {Rank, TensorLike} from '../../types'; diff --git a/tfjs-core/src/public/chained_ops/unsorted_segment_sum.ts b/tfjs-core/src/public/chained_ops/unsorted_segment_sum.ts index 01881dbbae4..1601f4efd5b 100644 --- a/tfjs-core/src/public/chained_ops/unsorted_segment_sum.ts +++ b/tfjs-core/src/public/chained_ops/unsorted_segment_sum.ts @@ -15,8 +15,7 @@ * ============================================================================= */ -// TODO update import path once op is modularized. -import {unsortedSegmentSum} from '../../ops/ops'; +import {unsortedSegmentSum} from '../../ops/unsorted_segment_sum'; import {Tensor, Tensor1D} from '../../tensor'; import {Rank, TensorLike1D} from '../../types'; diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index f3968bdb284..99ad871cef1 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -35,6 +35,7 @@ import {divGradConfig} from './gradients/Div_grad'; import {eluGradConfig} from './gradients/Elu_grad'; import {floorDivGradConfig} from './gradients/FloorDiv_grad'; import {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad'; +import {gatherGradConfig} from './gradients/GatherV2_grad'; import {greaterEqualGradConfig} from './gradients/GreaterEqual_grad'; import {identityGradConfig} from './gradients/Identity_grad'; import {lrnGradConfig} from './gradients/LRN_grad'; @@ -67,6 +68,7 @@ import {sumGradConfig} from './gradients/Sum_grad'; import {tileGradConfig} from './gradients/Tile_grad'; import {transposeGradConfig} from './gradients/Transpose_grad'; import {unpackGradConfig} from './gradients/Unpack_grad'; +import {unsortedSegmentSumGradConfig} from './gradients/UnsortedSegmentSum_grad'; import {GradConfig} from './kernel_registry'; import {registerGradient} from './kernel_registry'; @@ -93,6 +95,7 @@ const gradConfigs: GradConfig[] = [ eluGradConfig, floorDivGradConfig, fusedBatchNormGradConfig, + gatherGradConfig, greaterEqualGradConfig, identityGradConfig, lrnGradConfig, @@ -129,7 +132,8 @@ const gradConfigs: GradConfig[] = [ sumGradConfig, tileGradConfig, transposeGradConfig, - unpackGradConfig + unpackGradConfig, + unsortedSegmentSumGradConfig ]; for (const gradientConfig of gradConfigs) { diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index f330796dac6..4f5ee7d939a 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -87,6 +87,7 @@ import './ops/expand_dims_test'; import './ops/eye_test'; import './ops/fused_test'; import './ops/gather_nd_test'; +import './ops/gather_test'; import './ops/gram_schmidt_test'; import './ops/greater_equal_test'; import './ops/greater_test'; @@ -139,7 +140,6 @@ import './ops/reverse_3d_test'; import './ops/reverse_4d_test'; import './ops/reverse_test'; import './ops/scatter_nd_test'; -import './ops/segment_ops_test'; import './ops/selu_test'; import './ops/sigmoid_cross_entropy_test'; import './ops/signal_ops_test'; @@ -159,6 +159,7 @@ import './ops/topk_test'; import './ops/transpose_test'; import './ops/truncated_normal_test'; import './ops/unary_ops_test'; +import './ops/unsorted_segment_sum_test'; import './ops/unstack_test'; import './ops/where_async_test'; import './ops/where_test';