Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tfjs-core/src/gradients/BatchToSpaceND_grad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {BatchToSpaceND, BatchToSpaceNDAttrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {spaceToBatchND} from '../ops/array_ops';
import {Tensor} from '../tensor';

/**
 * Gradient configuration for the `BatchToSpaceND` kernel.
 *
 * The gradient with respect to the input `x` is produced lazily by calling
 * `spaceToBatchND` on the upstream gradient `dy` with the `blockShape` and
 * `crops` attributes of the forward kernel call.
 */
export const batchToSpaceNDGradConfig: GradConfig = {
  kernelName: BatchToSpaceND,
  gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
    // `saved` is not consulted; only `dy` and the kernel attributes are used.
    const kernelAttrs = attrs as {} as BatchToSpaceNDAttrs;
    const blockShape = kernelAttrs.blockShape;
    const crops = kernelAttrs.crops;

    // Deferred so the gradient is only computed when it is requested.
    const gradForX = () => spaceToBatchND(dy, blockShape, crops);
    return {x: gradForX};
  }
};
7 changes: 7 additions & 0 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ export interface BatchMatMulAttrs {
transposeB: boolean;
}

/** Kernel name string used to identify the BatchToSpaceND kernel. */
export const BatchToSpaceND = 'BatchToSpaceND';
/** Inputs of the BatchToSpaceND kernel: only the `x` entry. */
export type BatchToSpaceNDInputs = Pick<NamedTensorInfoMap, 'x'>;
/** Attributes passed alongside the BatchToSpaceND kernel inputs. */
export interface BatchToSpaceNDAttrs {
// One block size per spatial dimension.
blockShape: number[];
// One entry per spatial dimension.
// NOTE(review): presumably each entry is a `[cropStart, cropEnd]` pair —
// confirm against the op documentation.
crops: number[][];
}

export type BinaryInputs = Pick<NamedTensorInfoMap, 'a'|'b'>;

export const BroadcastTo = 'BroadcastTo';
Expand Down
80 changes: 1 addition & 79 deletions tfjs-core/src/ops/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,83 +159,6 @@ function stack_<T extends Tensor>(
return concat(expandedTensors, axis);
}

/**
* This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
* shape `blockShape + [batch]`, interleaves these blocks back into the grid
* defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
* the same rank as the input. The spatial dimensions of this intermediate
* result are then optionally cropped according to `crops` to produce the
* output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
* description.
*
* ```js
* const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
* const blockShape = [2, 2];
* const crops = [[0, 0], [0, 0]];
*
* x.batchToSpaceND(blockShape, crops).print();
* ```
*
* @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
* remainingShape`, where spatialShape has `M` dimensions.
* @param blockShape A 1-D array. Must have shape `[M]`, all values must
* be >= 1.
* @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.
* `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
* dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
* that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
*
* This operation is equivalent to the following steps:
*
* 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
* blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
* x.shape[N-1]]`
*
* 2. Permute dimensions of `reshaped` to produce `permuted` of shape `[batch /
* prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M],
* blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
*
* 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
* prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] *
* blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
*
* 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
* according to `crops` to produce the output of shape: `[batch /
* prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
* ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
* crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]`
*/
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function batchToSpaceND_<T extends Tensor>(
    x: T|TensorLike, blockShape: number[], crops: number[][]): T {
  const input = convertToTensor(x, 'x', 'batchToSpaceND');
  // Product of all block sizes; the batch dimension must be a multiple of it.
  const blockVolume = blockShape.reduce((acc, size) => acc * size);
  const numBlockDims = blockShape.length;

  // The input needs one batch dimension plus one dimension per block entry.
  const hasEnoughDims = input.rank >= numBlockDims + 1;
  util.assert(
      hasEnoughDims,
      () => `input rank is ${input.rank} but should be > than ` +
          `blockShape.length ${blockShape.length}`);

  // There must be exactly one crops entry per block dimension.
  const cropsMatchBlocks = crops.length === numBlockDims;
  util.assert(
      cropsMatchBlocks,
      () => `crops.length is ${crops.length} but should be equal to ` +
          `blockShape.length ${blockShape.length}`);

  // The batch dimension is split evenly across the blocks.
  const batchIsDivisible = input.shape[0] % blockVolume === 0;
  util.assert(
      batchIsDivisible,
      () => `input tensor batch is ${input.shape[0]} but is not divisible ` +
          `by the product of the elements of blockShape ` +
          `${blockShape.join(' * ')} === ${blockVolume}`);

  // The gradient undoes the op by applying spaceToBatchND to `dy` with the
  // same blockShape and crops.
  const gradFn = (dy: T) => ({
    $x: () => dy.spaceToBatchND(blockShape, crops)
  });

  return ENGINE.runKernelFunc(
      backend => backend.batchToSpaceND(input, blockShape, crops),
      {$x: input}, gradFn);
}

/**
* This operation divides "spatial" dimensions `[1, ..., M]` of the input into
* a grid of blocks of shape `blockShape`, and interleaves these blocks with
Expand Down Expand Up @@ -314,7 +237,7 @@ function spaceToBatchND_<T extends Tensor>(
blockShape.toString()}`);

const grad = (dy: T) => {
return {$x: () => dy.batchToSpaceND(blockShape, paddings)};
return {$x: () => dy.batchToSpaceND(blockShape, paddings) as T};
};

return ENGINE.runKernelFunc(
Expand Down Expand Up @@ -620,7 +543,6 @@ export {
print // Not wrapped in op() since no need to increase stack trace.
};

// Public ops: each `_`-suffixed identifier is passed through `op()` and
// exported under its un-suffixed name.
export const batchToSpaceND = op({batchToSpaceND_});
export const cast = op({cast_});
export const cumsum = op({cumsum_});
export const depthToSpace = op({depthToSpace_});
Expand Down
194 changes: 0 additions & 194 deletions tfjs-core/src/ops/array_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2852,200 +2852,6 @@ describeWithFlags('cumsum', ALL_ENVS, () => {
});
});

// Test suite for tf.batchToSpaceND. Covers output shape and values for
// several input ranks and block shapes, the argument-validation error
// messages, tensor-like input, and gradients (with and without clones).
describeWithFlags('batchToSpaceND', ALL_ENVS, () => {
// Batch of 4 single-element images with a 2x2 block and no cropping.
it('tensor4d, input shape=[4, 1, 1, 1], blockShape=[2, 2]', async () => {
const t = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 2, 2, 1]);
expectArraysClose(await res.data(), [1, 2, 3, 4]);
});

// Same as above but with a trailing dimension of size 3, which is expected
// to be carried through unchanged in the output shape.
it('tensor4d, input shape=[4, 1, 1, 3], blockShape=[2, 2]', async () => {
const t =
tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [4, 1, 1, 3]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 2, 2, 3]);
expectArraysClose(
await res.data(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
});

// 2x2 spatial input per batch entry; the expected output is a single 4x4
// image holding the values 1..16 in order.
it('tensor4d, input shape=[4, 2, 2, 1], blockShape=[2, 2]', async () => {
const t = tf.tensor4d(
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16], [4, 2, 2, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 4, 4, 1]);
expectArraysClose(
await res.data(),
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
});

// Non-zero crop: crops[1] = [2, 0]. The expected output contains none of
// the zero entries present in the input data.
it('tensor4d, input shape=[8, 1, 3, 1], blockShape=[2, 2]', async () => {
const t = tf.tensor4d(
[
0, 1, 3, 0, 9, 11, 0, 2, 4, 0, 10, 12,
0, 5, 7, 0, 13, 15, 0, 6, 8, 0, 14, 16
],
[8, 1, 3, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [2, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([2, 2, 4, 1]);
expectArraysClose(
await res.data(),
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
});

// NOTE(review): the `[1]` in the title appears to refer to the shape of
// blockShape (one element), not its value, which is [2] here — confirm.
it('tensor2d, blockShape [1]', async () => {
const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const blockShape = [2];
const crops = [[0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 4]);
expectArraysClose(await res.data(), [1, 3, 2, 4]);
});

// Rank-3 input with a single-element blockShape and crops [[0, 2]].
// NOTE(review): the title has a capitalization typo ("blockSHape").
it('tensor3d, blockSHape [1]', async () => {
const t = tf.tensor(
[
-61, 37, -68, 72, 31, 62, 0, -13, 28, 54, 96,
44, -55, -64, -88, -94, 65, -32, -96, -73, -2, -77,
-14, 47, 33, 15, 70, 20, 75, 28, 84, -13
],
[8, 2, 2]);
const blockShape = [2];
const crops = [[0, 2]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([4, 2, 2]);
expectArraysClose(
await res.data(),
[-61, 37, 65, -32, 31, 62, -2, -77, 28, 54, 33, 15, -55, -64, 75, 28]);
});

// Rank-3 input with a two-element blockShape, so every non-batch dimension
// is treated as spatial; both spatial dimensions have a crop start of 2.
it('tensor3d, blockShape [2]', async () => {
const t = tf.tensor(
[
-61, 37, -68, 72, 31, 62, 0, -13, 28, 54, 96,
44, -55, -64, -88, -94, 65, -32, -96, -73, -2, -77,
-14, 47, 33, 15, 70, 20, 75, 28, 84, -13
],
[8, 2, 2]);
const blockShape = [2, 2];
const crops = [[2, 0], [2, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([2, 2, 2]);
expectArraysClose(await res.data(), [72, 44, -73, 20, -13, -94, 47, -13]);
});

// Validation: a rank-4 input with a 4-element blockShape must be rejected
// with the rank error message.
it('throws when blockShape equal to input rank', () => {
const t = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
const blockShape = [2, 2, 2, 2];
const crops = [[0, 0], [0, 0], [0, 0], [0, 0]];

expect(() => tf.batchToSpaceND(t, blockShape, crops))
.toThrowError(
`input rank is ${t.rank} but should be > than blockShape.length ${
blockShape.length}`);
});

// Validation: crops has one row while blockShape has two entries.
it('throws when crops row dimension not equal to blockshape', () => {
const t = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
const blockShape = [2, 2];
const crops = [[0, 0]];

expect(() => tf.batchToSpaceND(t, blockShape, crops))
.toThrowError(`crops.length is ${
crops.length} but should be equal to blockShape.length ${
blockShape.length}`);
});

// Validation: batch size 5 is not a multiple of 2 * 2 = 4.
it('throws when input tensor batch not divisible by prod(blockShape)', () => {
const t = tf.tensor4d([1, 2, 3, 4, 5], [5, 1, 1, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];
const prod = blockShape.reduce((a, b) => a * b);

expect(() => tf.batchToSpaceND(t, blockShape, crops))
.toThrowError(
`input tensor batch is ${t.shape[0]} but is not divisible by the ` +
`product of the elements of blockShape ${
blockShape.join(' * ')} === ${prod}`);
});

// A nested plain array is accepted in place of a tensor.
it('accepts a tensor-like object', async () => {
const t = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]];
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 2, 2, 1]);
expectArraysClose(await res.data(), [1, 2, 3, 4]);
});

// Gradient with a non-zero crop: `dy` is passed as the second argument to
// the function returned by tf.grad. The expected gradient has the input's
// shape and contains zeros in addition to the values of `dy`.
it('gradients, input shape=[4, 2, 2], block shape=[2]', async () => {
const t = tf.tensor(
[-61, 37, -68, 72, 31, 62, 0, -13, 28, 54, 96, 44, -55, -64, -88, -94],
[4, 2, 2]);
const blockShape = [2];
const crops = [[0, 2]];
const dy = tf.tensor([.01, .02, .03, .04, .05, .06, .07, .08], [2, 2, 2]);

const gradient =
tf.grad(t => tf.batchToSpaceND(t, blockShape, crops))(t, dy);
expect(gradient.shape).toEqual([4, 2, 2]);
expectArraysClose(await gradient.data(), [
0.01, 0.02, 0, 0, 0.05, 0.06, 0, 0, 0.03, 0.04, 0, 0, 0.07, 0.08, 0, 0
]);
});

// Gradient without cropping: the expected gradient is a rearrangement of
// the values of `dy` into the input's shape.
it('gradients, input shape=[4, 2, 2, 1], block shape=[2, 2]', async () => {
const t = tf.tensor4d(
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16], [4, 2, 2, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];
const dy = tf.tensor(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4, 1]);

const gradient =
tf.grad(t => tf.batchToSpaceND(t, blockShape, crops))(t, dy);
expect(gradient.shape).toEqual([4, 2, 2, 1]);
expectArraysClose(
await gradient.data(),
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16]);
});

// Same gradient check as above, but the differentiated function clones both
// its input and its output; the expected gradient is unchanged.
it('gradient with clones, input=[4, 2, 2, 1], block shape=[2, 2]',
async () => {
const t = tf.tensor4d(
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16],
[4, 2, 2, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];
const dy = tf.tensor(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
[1, 4, 4, 1]);

const gradient = tf.grad(
t => tf.batchToSpaceND(t.clone(), blockShape, crops).clone())(t, dy);
expect(gradient.shape).toEqual([4, 2, 2, 1]);
expectArraysClose(
await gradient.data(),
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16]);
});
});

describeWithFlags('spaceToBatchND', ALL_ENVS, () => {
it('tensor4d, input shape=[1, 2, 2, 1], blockShape=[2, 2]', async () => {
const t = tf.tensor4d([[[[1], [2]], [[3], [4]]]], [1, 2, 2, 1]);
Expand Down
Loading