diff --git a/tfjs-core/src/gradients/BatchToSpaceND_grad.ts b/tfjs-core/src/gradients/BatchToSpaceND_grad.ts
index 68e5bde2557..5188c9ee0cf 100644
--- a/tfjs-core/src/gradients/BatchToSpaceND_grad.ts
+++ b/tfjs-core/src/gradients/BatchToSpaceND_grad.ts
@@ -17,7 +17,7 @@
 
 import {BatchToSpaceND, BatchToSpaceNDAttrs} from '../kernel_names';
 import {GradConfig, NamedAttrMap} from '../kernel_registry';
-import {spaceToBatchND} from '../ops/array_ops';
+import {spaceToBatchND} from '../ops/space_to_batch_nd';
 import {Tensor} from '../tensor';
 
 export const batchToSpaceNDGradConfig: GradConfig = {
diff --git a/tfjs-core/src/gradients/SpaceToBatchND_grad.ts b/tfjs-core/src/gradients/SpaceToBatchND_grad.ts
new file mode 100644
index 00000000000..af45aeebfc7
--- /dev/null
+++ b/tfjs-core/src/gradients/SpaceToBatchND_grad.ts
@@ -0,0 +1,29 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {SpaceToBatchND, SpaceToBatchNDAttrs} from '../kernel_names';
+import {GradConfig, NamedAttrMap} from '../kernel_registry';
+import {batchToSpaceND} from '../ops/batch_to_space_nd';
+import {Tensor} from '../tensor';
+
+export const spaceToBatchNDGradConfig: GradConfig = {
+  kernelName: SpaceToBatchND,
+  gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
+    const {blockShape, paddings} = attrs as {} as SpaceToBatchNDAttrs;
+    return {x: () => batchToSpaceND(dy, blockShape, paddings)};
+  }
+};
diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts
index 06b61ec7915..2a0af5b43a9 100644
--- a/tfjs-core/src/kernel_names.ts
+++ b/tfjs-core/src/kernel_names.ts
@@ -221,6 +221,13 @@ export interface PadV2Attrs {
   constantValue: number;
 }
 
+export const SpaceToBatchND = 'SpaceToBatchND';
+export type SpaceToBatchNDInputs = Pick<NamedTensorInfoMap, 'x'>;
+export interface SpaceToBatchNDAttrs {
+  blockShape: number[];
+  paddings: number[][];
+}
+
 export const SplitV = 'SplitV';
 export type SplitVInputs = Pick<NamedTensorInfoMap, 'x'>;
 export interface SplitVAttrs {
diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts
index 5c6ec4a2a0d..48cab082d49 100644
--- a/tfjs-core/src/ops/array_ops.ts
+++ b/tfjs-core/src/ops/array_ops.ts
@@ -159,91 +159,6 @@ function stack_(
   return concat(expandedTensors, axis);
 }
 
-/**
- * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
- * a grid of blocks of shape `blockShape`, and interleaves these blocks with
- * the "batch" dimension (0) such that in the output, the spatial
- * dimensions `[1, ..., M]` correspond to the position within the grid,
- * and the batch dimension combines both the position within a spatial block
- * and the original batch position. Prior to division into blocks,
- * the spatial dimensions of the input are optionally zero padded
- * according to `paddings`. See below for a precise description.
- *
- * ```js
- * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
- * const blockShape = [2, 2];
- * const paddings = [[0, 0], [0, 0]];
- *
- * x.spaceToBatchND(blockShape, paddings).print();
- * ```
- *
- * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
- * remainingShape`, where spatialShape has `M` dimensions.
- * @param blockShape A 1-D array. Must have shape `[M]`, all values must
- * be >= 1.
- * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
- * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
- * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
- * is required that
- * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
- *
- * This operation is equivalent to the following steps:
- *
- * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
- * according to `paddings` to produce `padded` of shape paddedShape.
- *
- * 2. Reshape `padded` to `reshapedPadded` of shape:
- * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
- *
- * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
- * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1]] + remainingShape`
- *
- * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
- * batch dimension, producing an output tensor of shape:
- * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1]] + remainingShape`
- */
-/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
-function spaceToBatchND_<T extends Tensor>(
-    x: T|TensorLike, blockShape: number[], paddings: number[][]): T {
-  const $x = convertToTensor(x, 'x', 'spaceToBatchND');
-
-  util.assert(
-      $x.rank >= 1 + blockShape.length,
-      () => `input rank ${$x.rank} should be > than [blockShape] ${
-          blockShape.length}`);
-
-  util.assert(
-      paddings.length === blockShape.length,
-      () => `paddings.shape[0] ${
-          paddings.length} must be equal to [blockShape] ${blockShape.length}`);
-
-  util.assert(
-      $x.shape.reduce(
-          (a, b, i) => {
-            if (i > 0 && i <= blockShape.length) {
-              return a &&
-                  ((b + paddings[i - 1][0] + paddings[i - 1][1]) %
-                       blockShape[i - 1] ===
-                   0);
-            }
-            return a;
-          },
-          true),
-      () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${
-          paddings.toString()} must be divisible by blockShapes ${
-          blockShape.toString()}`);
-
-  const grad = (dy: T) => {
-    return {$x: () => dy.batchToSpaceND(blockShape, paddings) as T};
-  };
-
-  return ENGINE.runKernelFunc(
-      backend => backend.spaceToBatchND($x, blockShape, paddings), {$x}, grad);
-}
-
 /**
  * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
  *
@@ -548,7 +463,6 @@ export const cumsum = op({cumsum_});
 export const depthToSpace = op({depthToSpace_});
 export const expandDims = op({expandDims_});
 export const reshape = op({reshape_});
-export const spaceToBatchND = op({spaceToBatchND_});
 export const squeeze = op({squeeze_});
 export const stack = op({stack_});
 export const unstack = op({unstack_});
diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts
index 9cc1e31d9f2..64db94702d9 100644
--- a/tfjs-core/src/ops/ops.ts
+++ b/tfjs-core/src/ops/ops.ts
@@ -63,6 +63,7 @@ export {randomGamma} from './random_gamma';
 export {randomNormal} from './random_normal';
 export {randomUniform} from './random_uniform';
 export {separableConv2d} from './separable_conv2d';
+export {spaceToBatchND} from './space_to_batch_nd';
 export {split} from './split';
 export {square} from './square';
 export {squaredDifference} from './squared_difference';
diff --git a/tfjs-core/src/ops/pool.ts b/tfjs-core/src/ops/pool.ts
index 6063385838f..78644a62581 100644
--- a/tfjs-core/src/ops/pool.ts
+++ b/tfjs-core/src/ops/pool.ts
@@ -22,10 +22,10 @@ import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 import * as util from '../util';
 
-import {spaceToBatchND} from './array_ops';
 import {batchToSpaceND} from './batch_to_space_nd';
 import * as conv_util from './conv_util';
 import {op} from './operation';
+import {spaceToBatchND} from './space_to_batch_nd';
 
 /**
  * Computes the 2D max pooling of an image.
diff --git a/tfjs-core/src/ops/space_to_batch_nd.ts b/tfjs-core/src/ops/space_to_batch_nd.ts
new file mode 100644
index 00000000000..a490f20d966
--- /dev/null
+++ b/tfjs-core/src/ops/space_to_batch_nd.ts
@@ -0,0 +1,117 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {SpaceToBatchND, SpaceToBatchNDAttrs, SpaceToBatchNDInputs} from '../kernel_names';
+import {NamedAttrMap} from '../kernel_registry';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+import * as util from '../util';
+
+import {op} from './operation';
+
+/**
+ * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
+ * a grid of blocks of shape `blockShape`, and interleaves these blocks with
+ * the "batch" dimension (0) such that in the output, the spatial
+ * dimensions `[1, ..., M]` correspond to the position within the grid,
+ * and the batch dimension combines both the position within a spatial block
+ * and the original batch position. Prior to division into blocks,
+ * the spatial dimensions of the input are optionally zero padded
+ * according to `paddings`. See below for a precise description.
+ * + * ```js + * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); + * const blockShape = [2, 2]; + * const paddings = [[0, 0], [0, 0]]; + * + * x.spaceToBatchND(blockShape, paddings).print(); + * ``` + * + * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + + * remainingShape`, where spatialShape has `M` dimensions. + * @param blockShape A 1-D array. Must have shape `[M]`, all values must + * be >= 1. + * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >= + * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad + * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It + * is required that + * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0` + * + * This operation is equivalent to the following steps: + * + * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input + * according to `paddings` to produce `padded` of shape paddedShape. + * + * 2. Reshape `padded` to `reshapedPadded` of shape: + * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ..., + * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape` + * + * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded` + * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ..., + * paddedShape[M] / blockShape[M-1]] + remainingShape` + * + * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the + * batch dimension, producing an output tensor of shape: + * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ..., + * paddedShape[M] / blockShape[M-1]] + remainingShape` + */ +/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ +function spaceToBatchND_( + x: T|TensorLike, blockShape: number[], paddings: number[][]): T { + const $x = convertToTensor(x, 'x', 'spaceToBatchND'); + + util.assert( + $x.rank >= 1 + blockShape.length, + () => `input rank ${$x.rank} should be > than [blockShape] ${ + blockShape.length}`); + + util.assert( + paddings.length === blockShape.length, + () => `paddings.shape[0] ${ + paddings.length} must be equal to [blockShape] ${blockShape.length}`); + + util.assert( + $x.shape.reduce( + (a, b, i) => { + if (i > 0 && i <= blockShape.length) { + return a && + ((b + paddings[i - 1][0] + paddings[i - 1][1]) % + blockShape[i - 1] === + 0); + } + return a; + }, + true), + () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${ + paddings.toString()} must be divisible by blockShapes ${ + blockShape.toString()}`); + + const forward: ForwardFunc = backend => + backend.spaceToBatchND($x, blockShape, paddings); + + const inputs: SpaceToBatchNDInputs = {x: $x}; + const attrs: SpaceToBatchNDAttrs = {blockShape, paddings}; + + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + SpaceToBatchND, attrs as {} as NamedAttrMap); +} + +export const spaceToBatchND = op({spaceToBatchND_}); diff --git a/tfjs-core/src/ops/space_to_batch_nd_test.ts b/tfjs-core/src/ops/space_to_batch_nd_test.ts new file mode 100644 index 00000000000..11da95ac170 --- /dev/null +++ b/tfjs-core/src/ops/space_to_batch_nd_test.ts @@ -0,0 +1,268 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '../index';
+import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
+import {expectArraysClose} from '../test_util';
+
+describeWithFlags('spaceToBatchND', ALL_ENVS, () => {
+  it('tensor4d, input shape=[1, 2, 2, 1], blockShape=[2, 2]', async () => {
+    const t = tf.tensor4d([[[[1], [2]], [[3], [4]]]], [1, 2, 2, 1]);
+    const blockShape = [2, 2];
+    const paddings = [[0, 0], [0, 0]];
+
+    const res = tf.spaceToBatchND(t, blockShape, paddings);
+    expect(res.shape).toEqual([4, 1, 1, 1]);
+    expectArraysClose(await res.data(), [1, 2, 3, 4]);
+  });
+
+  it('tensor4d, input shape=[1, 2, 2, 3], blockShape=[2, 2]', async () => {
+    const t = tf.tensor4d(
+        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], [1, 2, 2, 3]);
+    const blockShape = [2, 2];
+    const paddings = [[0, 0], [0, 0]];
+
+    const res = tf.spaceToBatchND(t, blockShape, paddings);
+    expect(res.shape).toEqual([4, 1, 1, 3]);
+    expectArraysClose(
+        await res.data(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
+  });
+
+  it('tensor4d, input shape=[1, 4, 4, 1], blockShape=[2, 2]', async () => {
+    const t = tf.tensor4d(
+        [[
+          [[1], [2], [3], [4]], [[5], [6], [7], [8]], [[9], [10], [11], [12]],
+          [[13], [14], [15], [16]]
+        ]],
+        [1, 4, 4, 1]);
+    const blockShape = [2, 2];
+    const paddings = [[0, 0], [0, 0]];
+
+    const res = tf.spaceToBatchND(t, blockShape, paddings);
+    expect(res.shape).toEqual([4, 2, 2, 1]);
+    expectArraysClose(
+        await res.data(),
+        [1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16]);
+  });
+
+  it('tensor4d, input shape=[2, 6, 6, 1], blockShape=[2, 2]', async () => {
+    const t = tf.tensor4d(
+        [
+          1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+          16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+          31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+          46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
+          61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72
+        ],
+        [2, 6, 6, 1]);
+    const blockShape = [2, 2];
+    const paddings = [[0, 0], [0, 0]];
+
+    const res = tf.spaceToBatchND(t, blockShape, paddings);
+    expect(res.shape).toEqual([8, 3, 3, 1]);
+    expectArraysClose(await res.data(), [
+      1, 3, 5, 13, 15, 17, 25, 27, 29, 37, 39, 41, 49, 51, 53, 61, 63, 65,
+      2, 4, 6, 14, 16, 18, 26, 28, 30, 38, 40, 42, 50, 52, 54, 62, 64, 66,
+      7, 9, 11, 19, 21, 23, 31, 33, 35, 43, 45, 47, 55, 57, 59, 67, 69, 71,
+      8, 10, 12, 20, 22, 24, 32, 34, 36, 44, 46, 48, 56, 58, 60, 68, 70, 72
+    ]);
+  });
+
+  it('tensor4d, input shape=[2, 2, 4, 1], blockShape=[2, 2]', async () => {
+    const t = tf.tensor4d(
+        [
+          [[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
+          [[[9], [10], [11], [12]], [[13], [14], [15], [16]]]
+        ],
+        [2, 2, 4, 1]);
+    const blockShape = [2, 2];
+    const paddings = [[0, 0], [2, 0]];
+
+    const res = tf.spaceToBatchND(t, blockShape, paddings);
+    expect(res.shape).toEqual([8, 1, 3, 1]);
+    expectArraysClose(await res.data(), [
+      0, 1, 3, 0, 9, 11, 0, 2, 4, 0, 10, 12,
+      0, 5, 7, 0, 13, 15, 0, 6, 8, 0, 14, 16
+    ]);
+  });
+
+  it('tensor2d, blockShape [2]', async () => {
+    const t = tf.tensor2d([1, 3, 2, 4], [1, 4]);
+    const blockShape = [2];
+    const paddings = [[0, 0]];
+
+    const res = tf.spaceToBatchND(t, blockShape, paddings);
+    expect(res.shape).toEqual([2, 2]);
+    expectArraysClose(await res.data(), [1, 2, 3, 4]);
+  });
+
+  it('throws when blockShape equal to input rank', () => {
+    const t = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
+    const blockShape = [2, 2, 2, 2];
+    const paddings = [[0, 0], [0, 0], [0, 0], [0, 0]];
+
+    expect(() => tf.spaceToBatchND(t, blockShape, paddings))
+        .toThrowError('input rank 4 should be > than [blockShape] 4');
+  });
+
+  it('throws when paddings row dimension not equal to blockshape', () => {
+    const t = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
+    const blockShape = [2, 2];
+    const paddings = [[0, 0]];
+
+    expect(() => tf.spaceToBatchND(t, blockShape, paddings))
+        .toThrowError('paddings.shape[0] 1 must be equal to [blockShape] 2');
+  });
+
+  it('throws when input tensor spatial dimension not divisible by blockshapes',
+     () => {
+       const t = tf.tensor4d([1, 2, 3, 4, 5, 6], [1, 2, 3, 1]);
+       const blockShape = [2, 2];
+       const paddings = [[0, 0], [0, 0]];
+
+       expect(() => tf.spaceToBatchND(t, blockShape, paddings))
+           .toThrowError(
+               'input spatial dimensions 2,3,1 with paddings 0,0,0,0 must be ' +
+               'divisible by blockShapes 2,2');
+     });
+
+  it('accepts a tensor-like object', async () => {
+    const t = [[[[1], [2]], [[3], [4]]]];
+    const blockShape = [2, 2];
+    const paddings = [[0, 0], [0, 0]];
+
+    const res = tf.spaceToBatchND(t, blockShape, paddings);
+    expect(res.shape).toEqual([4, 1, 1, 1]);
+    expectArraysClose(await res.data(), [1, 2, 3, 4]);
+  });
+});
+
+describeWithFlags('batchToSpaceND X spaceToBatchND', ALL_ENVS, () => {
+  it('tensor4d, input shape=[4, 1, 1, 1], blockShape=[2, 2]', async () => {
+    const t = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
+    const blockShape = [2, 2];
+    const crops = [[0, 0], [0, 0]];
+    const paddings = [[0, 0], [0, 0]];
+
+    const b2s = tf.batchToSpaceND(t, blockShape, crops);
+    expect(b2s.shape).toEqual([1, 2, 2, 1]);
+    expectArraysClose(await b2s.data(), [1, 2, 3, 4]);
+
+    const s2b = tf.spaceToBatchND(b2s, blockShape, paddings);
+    expect(s2b.shape).toEqual([4, 1, 1, 1]);
+    expectArraysClose(await s2b.data(), [1, 2, 3, 4]);
+  });
+
+  it('tensor4d, input shape=[2, 6, 6, 1], blockShape=[2, 2]', async () => {
+    const t = tf.tensor4d(
+        [
+          1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+          16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+          31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+          46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
+          61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72
+        ],
+        [2, 6, 6, 1]);
+    const blockShape = [2, 2];
+    const crops = [[0, 0], [0, 0]];
+    const paddings = [[0, 0], [0, 0]];
+
+    const s2b = tf.spaceToBatchND(t, blockShape, paddings);
+    expect(s2b.shape).toEqual([8, 3, 3, 1]);
+    expectArraysClose(await s2b.data(), [
+      1, 3, 5, 13, 15, 17, 25, 27, 29, 37, 39, 41, 49, 51, 53, 61, 63, 65,
+      2, 4, 6, 14, 16, 18, 26, 28, 30, 38, 40, 42, 50, 52, 54, 62, 64, 66,
+      7, 9, 11, 19, 21, 23, 31, 33, 35, 43, 45, 47, 55, 57, 59, 67, 69, 71,
+      8, 10, 12, 20, 22, 24, 32, 34, 36, 44, 46, 48, 56, 58, 60, 68, 70, 72
+    ]);
+
+    const b2s = tf.batchToSpaceND(s2b, blockShape, crops);
+    expect(b2s.shape).toEqual([2, 6, 6, 1]);
+    expectArraysClose(await b2s.data(), [
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+      19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
+      37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+      55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72
+    ]);
+  });
+
+  it('gradients, input shape=[4, 2, 2], block shape=[2]', async () => {
+    const t = tf.tensor(
+        [-61, 37, -68, 72, 31, 62, 0, -13, 28, 54, 96, 44, -55, -64, -88, -94],
+        [4, 2, 2]);
+    const blockShape = [2];
+    const paddings = [[0, 2]];
+    const dy = tf.tensor(
+        [
+          1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+          17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32
+        ],
+        [8, 2, 2]);
+
+    const gradient =
+        tf.grad(t => tf.spaceToBatchND(t, blockShape, paddings))(t, dy);
+    expect(gradient.shape).toEqual([4, 2, 2]);
+    expectArraysClose(
+        await gradient.data(),
+        [1, 2, 17, 18, 5, 6, 21, 22, 9, 10, 25, 26, 13, 14, 29, 30]);
+  });
+
+  it('gradient with clones input=[4, 2, 2], block shape=[2]', async () => {
+    const t = tf.tensor(
+        [-61, 37, -68, 72, 31, 62, 0, -13, 28, 54, 96, 44, -55, -64, -88, -94],
+        [4, 2, 2]);
+    const blockShape = [2];
+    const paddings = [[0, 2]];
+    const dy = tf.tensor(
+        [
+          1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+          17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32
+        ],
+        [8, 2, 2]);
+
+    const gradient = tf.grad(
+        t => tf.spaceToBatchND(t.clone(), blockShape, paddings).clone())(t, dy);
+    expect(gradient.shape).toEqual([4, 2, 2]);
+    expectArraysClose(
+        await gradient.data(),
+        [1, 2, 17, 18, 5, 6, 21, 22, 9, 10, 25, 26, 13, 14, 29, 30]);
+  });
+
+  it('gradients, input shape=[2, 2, 4, 1], block shape=[2, 2]', async () => {
+    const t = tf.tensor4d(
+        [
+          [[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
+          [[[9], [10], [11], [12]], [[13], [14], [15], [16]]]
+        ],
+        [2, 2, 4, 1]);
+    const blockShape = [2, 2];
+    const paddings = [[0, 0], [2, 0]];
+    const dy = tf.tensor(
+        [
+          1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+          13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24
+        ],
+        [8, 1, 3, 1]);
+
+    const gradient =
+        tf.grad(t => tf.spaceToBatchND(t, blockShape, paddings))(t, dy);
+    expect(gradient.shape).toEqual([2, 2, 4, 1]);
+    expectArraysClose(
+        await gradient.data(),
+        [2, 8, 3, 9, 14, 20, 15, 21, 5, 11, 6, 12, 17, 23, 18, 24]);
+  });
+});
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
index e8a285a4954..ba785603de0 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
@@ -41,6 +41,7 @@ import './pad';
 import './separable_conv2d';
 import './split';
 import './squared_difference';
+import './space_to_batch_nd';
 import './sub';
 import './tile';
 import './transpose';
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
index 67cb0ce55e6..2378927d7b1 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
@@ -49,6 +49,7 @@ const CHAINED_OPS = [
   'pad',
   'max',
   'separableConv2d',
+  'spaceToBatchND',
   'split',
   'square',
   'sub',
diff --git a/tfjs-core/src/public/chained_ops/space_to_batch_nd.ts b/tfjs-core/src/public/chained_ops/space_to_batch_nd.ts
new file mode 100644
index 00000000000..a9c84bf6b04
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/space_to_batch_nd.ts
@@ -0,0 +1,32 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {spaceToBatchND} from '../../ops/space_to_batch_nd';
+import {Tensor} from '../../tensor';
+import {Rank} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    spaceToBatchND<R extends Rank>(blockShape: number[], paddings: number[][]):
+        Tensor<R>;
+  }
+}
+
+Tensor.prototype.spaceToBatchND = function<R extends Rank>(
+    blockShape: number[], paddings: number[][]): Tensor<R> {
+  this.throwIfDisposed();
+  return spaceToBatchND(this, blockShape, paddings);
+};
diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts
index 00248c5e305..39ad8d0bdaf 100644
--- a/tfjs-core/src/register_all_gradients.ts
+++ b/tfjs-core/src/register_all_gradients.ts
@@ -32,6 +32,7 @@ import {lrnGradConfig} from './gradients/LRN_grad';
 import {maxGradConfig} from './gradients/Max_grad';
 import {oneHotGradConfig} from './gradients/OneHot_grad';
 import {padV2GradConfig} from './gradients/PadV2_grad';
+import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad';
 import {splitVGradConfig} from './gradients/SplitV_grad';
 import {squareGradConfig} from './gradients/Square_grad';
 import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad';
@@ -43,29 +44,18 @@ import {registerGradient} from './kernel_registry';
 
 // Export all kernel configs here so that the package can auto register them
 const gradConfigs: GradConfig[] = [
-  addGradConfig,
-  addNGradConfig,
-  batchMatMulGradConfig,
-  batchToSpaceNDGradConfig,
-  broadcastToGradConfig,
-  concatGradConfig,
-  conv2DGradConfig,
-  conv2DBackpropInputGradConfig,
-  conv3DGradConfig,
-  depthwiseConv2dNativeGradConfig,
-  divGradConfig,
-  fusedBatchNormGradConfig,
-  greaterEqualGradConfig,
-  identityGradConfig,
-  lrnGradConfig,
-  oneHotGradConfig,
-  padV2GradConfig,
-  splitVGradConfig,
-  maxGradConfig,
-  squareGradConfig,
-  squaredDifferenceGradConfig,
-  tileGradConfig,
-  transposeGradConfig,
+  addGradConfig, addNGradConfig,
+  batchMatMulGradConfig, batchToSpaceNDGradConfig,
+  broadcastToGradConfig, concatGradConfig,
+  conv2DGradConfig, conv2DBackpropInputGradConfig,
+  conv3DGradConfig, depthwiseConv2dNativeGradConfig,
+  divGradConfig, fusedBatchNormGradConfig,
+  greaterEqualGradConfig, identityGradConfig,
+  lrnGradConfig, oneHotGradConfig,
+  padV2GradConfig, splitVGradConfig,
+  maxGradConfig, spaceToBatchNDGradConfig,
+  squareGradConfig, squaredDifferenceGradConfig,
+  tileGradConfig, transposeGradConfig,
   subGradConfig
 ];
 
diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts
index 446c83b3c01..8ccfe74214f 100644
--- a/tfjs-core/src/tensor.ts
+++ b/tfjs-core/src/tensor.ts
@@ -295,8 +295,6 @@ export interface OpHandler {
       strides?: [number, number]|number): T;
   unsortedSegmentSum<T extends Tensor>(
       x: T, segmentIds: Tensor1D|TensorLike1D, numSegments: number): T;
-  spaceToBatchND<T extends Tensor>(
-      x: T, blockShape: number[], paddings: number[][]): T;
   topk<T extends Tensor>(x: T, k: number, sorted: boolean):
       {values: T, indices: T};
   stridedSlice<T extends Tensor>(
@@ -1182,12 +1180,6 @@ export class Tensor<R extends Rank = Rank> {
     return opHandler.unsortedSegmentSum(this, segmentIds, numSegments);
   }
 
-  spaceToBatchND<T extends Tensor>(
-      this: T, blockShape: number[], paddings: number[][]): T {
-    this.throwIfDisposed();
-    return opHandler.spaceToBatchND(this, blockShape, paddings);
-  }
-
   topk<T extends Tensor>(this: T, k = 1, sorted = true):
       {values: T, indices: T} {
     this.throwIfDisposed();
diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts
index 2b8348169cc..b73e7e87338 100644
--- a/tfjs-core/src/tests.ts
+++ b/tfjs-core/src/tests.ts
@@ -103,6 +103,7 @@ import './ops/signal_ops_test';
 import './ops/slice_test';
 import './ops/slice_util_test';
 import './ops/softmax_test';
+import './ops/space_to_batch_nd_test';
 import './ops/sparse_to_dense_test';
 import './ops/spectral_ops_test';
 import './ops/strided_slice_test';
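For a quick illustration of what this diff wires together, here is a minimal usage sketch of the modular op and its registered gradient. It mirrors the new unit tests; the '@tensorflow/tfjs-core' import path is an assumption about the consuming package, not part of the diff.

// Illustrative sketch: exercises tf.spaceToBatchND and its gradient.
import * as tf from '@tensorflow/tfjs-core';

// [1, 2, 2, 1] input, 2x2 blocks, no padding -> [4, 1, 1, 1] output.
const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
const y = tf.spaceToBatchND(x, [2, 2], [[0, 0], [0, 0]]);
y.print();

// The gradient registered in SpaceToBatchND_grad.ts routes dy back through
// batchToSpaceND, so the gradient has the input's original shape.
const dx = tf.grad(t => tf.spaceToBatchND(t, [2, 2], [[0, 0], [0, 0]]))(x);
dx.print();  // shape [1, 2, 2, 1]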