From b3235d0fa761341ad2402cc55fe8196cfcb9b659 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 15 May 2020 08:59:17 -0400 Subject: [PATCH] modularize --- tfjs-core/src/kernel_names.ts | 7 ++ tfjs-core/src/ops/array_ops.ts | 72 +----------- tfjs-core/src/ops/array_ops_test.ts | 61 ----------- tfjs-core/src/ops/depth_to_space.ts | 103 ++++++++++++++++++ tfjs-core/src/ops/depth_to_space_test.ts | 81 ++++++++++++++ tfjs-core/src/ops/ops.ts | 1 + .../src/public/chained_ops/depth_to_space.ts | 32 ++++++ .../chained_ops/register_all_chained_ops.ts | 1 + .../register_all_chained_ops_test.ts | 1 + tfjs-core/src/tensor.ts | 7 -- tfjs-core/src/tests.ts | 1 + 11 files changed, 229 insertions(+), 138 deletions(-) create mode 100644 tfjs-core/src/ops/depth_to_space.ts create mode 100644 tfjs-core/src/ops/depth_to_space_test.ts create mode 100644 tfjs-core/src/public/chained_ops/depth_to_space.ts diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 2a0af5b43a9..2101066febc 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -107,6 +107,13 @@ export interface Conv3DBackpropInputAttrs { pad: 'valid'|'same'; } +export const DepthToSpace = 'DepthToSpace'; +export type DepthToSpaceInputs = Pick; +export interface DepthToSpaceAttrs { + blockSize: number; + dataFormat: 'NHWC'|'NCHW'; +} + export const DepthwiseConv2dNative = 'DepthwiseConv2dNative'; export type DepthwiseConv2dNativeInputs = Pick; diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 48cab082d49..98d48fd7f63 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -16,9 +16,9 @@ */ import {ENGINE} from '../engine'; -import {Tensor, Tensor4D, TensorBuffer} from '../tensor'; +import {Tensor, TensorBuffer} from '../tensor'; import {convertToTensor, convertToTensorArray} from '../tensor_util_env'; -import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike, TensorLike4D} from '../types'; +import {DataType, 
DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; import * as util from '../util'; import {getAxesPermutation, getInnerMostAxes} from './axis_util'; import {concat} from './concat'; @@ -271,73 +271,6 @@ function expandDims_( return reshape($x, newShape as ShapeMap[R2]); } -/** - * Rearranges data from depth into blocks of spatial data. More specifically, - * this op outputs a copy of the input tensor where values from the `depth` - * dimension are moved in spatial blocks to the `height` and `width` dimensions. - * The attr `blockSize` indicates the input block size and how the data is - * moved. - * - * - Chunks of data of size `blockSize * blockSize` from depth are rearranged - * into non-overlapping blocks of size `blockSize x blockSize` - * - * - The width the output tensor is `inputWidth * blockSize`, whereas the - * height is `inputHeight * blockSize` - * - * - The Y, X coordinates within each block of the output image are determined - * by the high order component of the input channel index - * - * - The depth of the input tensor must be divisible by `blockSize * - * blockSize` - * - * The `dataFormat` attr specifies the layout of the input and output tensors - * with the following options: "NHWC": [ `batch, height, width, channels` ] - * "NCHW": [ `batch, channels, height, width` ] - * - * ```js - * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]); - * const blockSize = 2; - * const dataFormat = "NHWC"; - * - * tf.depthToSpace(x, blockSize, dataFormat).print(); - * ``` - * - * @param x The input tensor of rank 4 - * @param blockSIze An `int` that is `>= 2`. The size of the spatial block - * @param dataFormat An optional string from: "NHWC", "NCHW". 
Defaults to "NHWC" - */ -/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ -function depthToSpace_( - x: Tensor4D|TensorLike4D, blockSize: number, - dataFormat: 'NHWC'|'NCHW' = 'NHWC'): Tensor4D { - const $x = convertToTensor(x, 'x', 'depthToSpace') as Tensor4D; - - const inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2]; - const inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3]; - const inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1]; - - util.assert( - inputHeight * blockSize >= 0, - () => `Negative dimension size caused by overflow when multiplying - ${inputHeight} and ${blockSize} for depthToSpace with input shape - ${$x.shape}`); - - util.assert( - inputWidth * blockSize >= 0, - () => `Negative dimension size caused by overflow when multiplying - ${inputWidth} and ${blockSize} for depthToSpace with input shape - ${$x.shape}`); - - util.assert( - (inputDepth % (blockSize * blockSize) === 0), - () => `Dimension size must be evenly divisible by ${ - blockSize * blockSize} but is ${ - inputDepth} for depthToSpace with input shape ${$x.shape}`); - - return ENGINE.runKernelFunc( - backend => backend.depthToSpace($x, blockSize, dataFormat), {$x}); -} - /** * Computes the difference between two lists of numbers. 
* @@ -460,7 +393,6 @@ export { export const cast = op({cast_}); export const cumsum = op({cumsum_}); -export const depthToSpace = op({depthToSpace_}); export const expandDims = op({expandDims_}); export const reshape = op({reshape_}); export const squeeze = op({squeeze_}); diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index 60ed116b3d5..bf6ae3d510c 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -3100,67 +3100,6 @@ describeWithFlags('batchToSpaceND X spaceToBatchND', ALL_ENVS, () => { }); }); -describeWithFlags('depthToSpace', ALL_ENVS, () => { - it('tensor4d, input shape=[1, 1, 1, 4], blockSize=2, format=NHWC', - async () => { - const t = tf.tensor4d([[[[1, 2, 3, 4]]]]); - const blockSize = 2; - const dataFormat = 'NHWC'; - - const res = tf.depthToSpace(t, blockSize, dataFormat); - expect(res.shape).toEqual([1, 2, 2, 1]); - expectArraysClose(await res.data(), [1, 2, 3, 4]); - }); - - it('tensor4d, input shape=[1, 1, 1, 12], blockSize=2, format=NHWC', - async () => { - const t = tf.tensor4d([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]); - const blockSize = 2; - const dataFormat = 'NHWC'; - - const res = tf.depthToSpace(t, blockSize, dataFormat); - expect(res.shape).toEqual([1, 2, 2, 3]); - expectArraysClose( - await res.data(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - }); - - it('tensor4d, input shape=[1, 2, 2, 4], blockSize=2, format=NHWC', - async () => { - const t = tf.tensor4d([ - [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]]] - ]); - const blockSize = 2; - const dataFormat = 'NHWC'; - - const res = tf.depthToSpace(t, blockSize, dataFormat); - expect(res.shape).toEqual([1, 4, 4, 1]); - expectArraysClose( - await res.data(), - [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16]); - }); - - it('throws when depth not divisible by blockSize * blockSize', () => { - const t = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]); - const blockSize = 3; - - expect(() 
=> tf.depthToSpace(t, blockSize)) - .toThrowError(`Dimension size must be evenly divisible by ${ - blockSize * blockSize} but is ${ - t.shape[3]} for depthToSpace with input shape ${t.shape}`); - }); -}); - -describeWithFlags('depthToSpace', BROWSER_ENVS, () => { - it('throws when blocksize < 2', () => { - const t = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]); - const blockSize = 1; - - expect(() => tf.depthToSpace(t, blockSize)) - .toThrowError( - `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`); - }); -}); - describeWithFlags('setdiff1dAsync', ALL_ENVS, () => { it('1d int32 tensor', async () => { const x = tf.tensor1d([1, 2, 3, 4], 'int32'); diff --git a/tfjs-core/src/ops/depth_to_space.ts b/tfjs-core/src/ops/depth_to_space.ts new file mode 100644 index 00000000000..98bb74d17cb --- /dev/null +++ b/tfjs-core/src/ops/depth_to_space.ts @@ -0,0 +1,103 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {DepthToSpace, DepthToSpaceAttrs, DepthToSpaceInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike4D} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; + +/** + * Rearranges data from depth into blocks of spatial data. More specifically, + * this op outputs a copy of the input tensor where values from the `depth` + * dimension are moved in spatial blocks to the `height` and `width` dimensions. + * The attr `blockSize` indicates the input block size and how the data is + * moved. + * + * - Chunks of data of size `blockSize * blockSize` from depth are rearranged + * into non-overlapping blocks of size `blockSize x blockSize` + * + * - The width of the output tensor is `inputWidth * blockSize`, whereas the + * height is `inputHeight * blockSize` + * + * - The Y, X coordinates within each block of the output image are determined + * by the high order component of the input channel index + * + * - The depth of the input tensor must be divisible by `blockSize * + * blockSize` + * + * The `dataFormat` attr specifies the layout of the input and output tensors + * with the following options: "NHWC": [ `batch, height, width, channels` ] + * "NCHW": [ `batch, channels, height, width` ] + * + * ```js + * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]); + * const blockSize = 2; + * const dataFormat = "NHWC"; + * + * tf.depthToSpace(x, blockSize, dataFormat).print(); + * ``` + * + * @param x The input tensor of rank 4 + * @param blockSize An `int` that is `>= 2`. The size of the spatial block + * @param dataFormat An optional string from: "NHWC", "NCHW". 
Defaults to "NHWC" + */ +/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ +function depthToSpace_( + x: Tensor4D|TensorLike4D, blockSize: number, + dataFormat: 'NHWC'|'NCHW' = 'NHWC'): Tensor4D { + const $x = convertToTensor(x, 'x', 'depthToSpace') as Tensor4D; + + const inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2]; + const inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3]; + const inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1]; + + util.assert( + inputHeight * blockSize >= 0, + () => `Negative dimension size caused by overflow when multiplying + ${inputHeight} and ${blockSize} for depthToSpace with input shape + ${$x.shape}`); + + util.assert( + inputWidth * blockSize >= 0, + () => `Negative dimension size caused by overflow when multiplying + ${inputWidth} and ${blockSize} for depthToSpace with input shape + ${$x.shape}`); + + util.assert( + (inputDepth % (blockSize * blockSize) === 0), + () => `Dimension size must be evenly divisible by ${ + blockSize * blockSize} but is ${ + inputDepth} for depthToSpace with input shape ${$x.shape}`); + + const forward: ForwardFunc = backend => + backend.depthToSpace($x, blockSize, dataFormat); + + const inputs: DepthToSpaceInputs = {x: $x}; + const attrs: DepthToSpaceAttrs = {blockSize, dataFormat}; + + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + DepthToSpace, attrs as {} as NamedAttrMap); +} + +export const depthToSpace = op({depthToSpace_}); diff --git a/tfjs-core/src/ops/depth_to_space_test.ts b/tfjs-core/src/ops/depth_to_space_test.ts new file mode 100644 index 00000000000..33a191f61ba --- /dev/null +++ b/tfjs-core/src/ops/depth_to_space_test.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('depthToSpace', ALL_ENVS, () => { + it('tensor4d, input shape=[1, 1, 1, 4], blockSize=2, format=NHWC', + async () => { + const t = tf.tensor4d([[[[1, 2, 3, 4]]]]); + const blockSize = 2; + const dataFormat = 'NHWC'; + + const res = tf.depthToSpace(t, blockSize, dataFormat); + expect(res.shape).toEqual([1, 2, 2, 1]); + expectArraysClose(await res.data(), [1, 2, 3, 4]); + }); + + it('tensor4d, input shape=[1, 1, 1, 12], blockSize=2, format=NHWC', + async () => { + const t = tf.tensor4d([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]); + const blockSize = 2; + const dataFormat = 'NHWC'; + + const res = tf.depthToSpace(t, blockSize, dataFormat); + expect(res.shape).toEqual([1, 2, 2, 3]); + expectArraysClose( + await res.data(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + }); + + it('tensor4d, input shape=[1, 2, 2, 4], blockSize=2, format=NHWC', + async () => { + const t = tf.tensor4d([ + [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]]] + ]); + const blockSize = 2; + const dataFormat = 'NHWC'; + + const res = tf.depthToSpace(t, blockSize, dataFormat); + expect(res.shape).toEqual([1, 4, 4, 1]); + expectArraysClose( + await res.data(), + [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16]); + }); + + it('throws when depth not divisible by blockSize * blockSize', () => { + 
const t = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]); + const blockSize = 3; + + expect(() => tf.depthToSpace(t, blockSize)) + .toThrowError(`Dimension size must be evenly divisible by ${ + blockSize * blockSize} but is ${ + t.shape[3]} for depthToSpace with input shape ${t.shape}`); + }); +}); + +describeWithFlags('depthToSpace', BROWSER_ENVS, () => { + it('throws when blocksize < 2', () => { + const t = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]); + const blockSize = 1; + + expect(() => tf.depthToSpace(t, blockSize)) + .toThrowError( + `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 64db94702d9..98564392a09 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -35,6 +35,7 @@ export {conv2d} from './conv2d'; export {conv2dTranspose} from './conv2d_transpose'; export {conv3d} from './conv3d'; export {conv3dTranspose} from './conv3d_transpose'; +export {depthToSpace} from './depth_to_space'; export {depthwiseConv2d} from './depthwise_conv2d'; export {diag} from './diag'; export {div} from './div'; diff --git a/tfjs-core/src/public/chained_ops/depth_to_space.ts b/tfjs-core/src/public/chained_ops/depth_to_space.ts new file mode 100644 index 00000000000..cf4eadd827c --- /dev/null +++ b/tfjs-core/src/public/chained_ops/depth_to_space.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {depthToSpace} from '../../ops/depth_to_space'; +import {Tensor, Tensor4D} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor<R extends Rank = Rank> { + depthToSpace<T extends Tensor4D>( + blockSize: number, dataFormat: 'NHWC'|'NCHW'): T; + } +} + +Tensor.prototype.depthToSpace = function<T extends Tensor4D>( + blockSize: number, dataFormat: 'NHWC'|'NCHW'): T { + this.throwIfDisposed(); + return depthToSpace(this, blockSize, dataFormat) as T; +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index ba785603de0..05d39a9b729 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -23,6 +23,7 @@ import './concat'; import './conv1d'; import './conv2d'; import './conv2d_transpose'; +import './depth_to_space'; import './depthwise_conv2d'; import './depthwise_conv2D_deprecated'; import './div'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 2378927d7b1..da1641864bb 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -32,6 +32,7 @@ const CHAINED_OPS = [ 'conv1d', 'conv2d', 'conv2dTranspose', + 'depthToSpace', 'depthwiseConv2d', 'depthwiseConv2D', 'div', diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 8ccfe74214f..14ed25f29bf 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -301,7 +301,6 @@ export interface OpHandler { x: Tensor, begin: number[], end: number[], strides: number[], beginMask: number, endMask: number, ellipsisMask: number, newAxisMask: number, 
shrinkAxisMask: number): Tensor; - depthToSpace(x: Tensor4D, blockSize: number, dataFormat: string): Tensor4D; spectral: { fft(x: Tensor): Tensor; ifft(x: Tensor): Tensor; rfft(x: Tensor): Tensor; irfft(x: Tensor): Tensor @@ -1196,12 +1195,6 @@ export class Tensor { newAxisMask, shrinkAxisMask); } - depthToSpace(this: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'): - Tensor4D { - this.throwIfDisposed(); - return opHandler.depthToSpace(this, blockSize, dataFormat); - } - fft(this: Tensor): Tensor { this.throwIfDisposed(); return opHandler.spectral.fft(this); diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index b73e7e87338..4d7e5f3b638 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -63,6 +63,7 @@ import './ops/conv2d_transpose_test'; import './ops/conv3d_test'; import './ops/conv3d_transpose_test'; import './ops/conv_util_test'; +import './ops/depth_to_space_test'; import './ops/diag_test'; import './ops/dropout_test'; import './ops/dropout_util_test';