From 50cbd3cd53523812f844b2e2a365287f1d814a64 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 25 Jun 2020 14:55:25 -0400 Subject: [PATCH 1/5] reshape --- tfjs-core/src/ops/all.ts | 2 +- tfjs-core/src/ops/any.ts | 2 +- tfjs-core/src/ops/fused_ops.ts | 2 +- tfjs-core/src/ops/reduction_ops.ts | 2 +- tfjs-core/src/ops/reduction_ops_util.ts | 4 +-- tfjs-core/src/ops/reshape.ts | 8 ++--- tfjs-core/src/ops/segment_ops.ts | 2 +- .../chained_ops/register_all_chained_ops.ts | 1 + .../register_all_chained_ops_test.ts | 1 + tfjs-core/src/public/chained_ops/reshape.ts | 30 +++++++++++++++++++ tfjs-core/src/tensor.ts | 27 +++++------------ 11 files changed, 50 insertions(+), 31 deletions(-) create mode 100644 tfjs-core/src/public/chained_ops/reshape.ts diff --git a/tfjs-core/src/ops/all.ts b/tfjs-core/src/ops/all.ts index 7a786380e12..1e7d382d289 100644 --- a/tfjs-core/src/ops/all.ts +++ b/tfjs-core/src/ops/all.ts @@ -70,7 +70,7 @@ function all_( const res = backend.all($x, axes); if (keepDims) { const newShape = expandShapeToKeepDim(res.shape, origAxes); - return res.reshape(newShape) as T; + return res.reshape(newShape); } return res as T; }; diff --git a/tfjs-core/src/ops/any.ts b/tfjs-core/src/ops/any.ts index 102d995ba47..6be067628eb 100644 --- a/tfjs-core/src/ops/any.ts +++ b/tfjs-core/src/ops/any.ts @@ -70,7 +70,7 @@ function any_( const res = backend.any($x, axes); if (keepDims) { const newShape = expandShapeToKeepDim(res.shape, origAxes); - return res.reshape(newShape) as T; + return res.reshape(newShape); } return res as T; }; diff --git a/tfjs-core/src/ops/fused_ops.ts b/tfjs-core/src/ops/fused_ops.ts index 5bc1e66fa49..f416d224a7d 100644 --- a/tfjs-core/src/ops/fused_ops.ts +++ b/tfjs-core/src/ops/fused_ops.ts @@ -256,7 +256,7 @@ function fusedMatMul_({ }, inputs, grad, '_FusedMatMul', {transposeA, transposeB, activation}, inputsToSave, outputsToSave); - return res.reshape(outShape) as T; + return res.reshape(outShape); } /** diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index a069aa1b434..8c1aea8f9f8 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -221,7 +221,7 @@ function min_( }, {x: $x}, grad, 'Min', {axes}, inputsToSave, outputsToSave); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); - res = res.reshape(newShape) as T; + res = res.reshape(newShape); } return res; } diff --git a/tfjs-core/src/ops/reduction_ops_util.ts b/tfjs-core/src/ops/reduction_ops_util.ts index a8a16d7d35e..cd5e58a0129 100644 --- a/tfjs-core/src/ops/reduction_ops_util.ts +++ b/tfjs-core/src/ops/reduction_ops_util.ts @@ -24,10 +24,10 @@ import * as axis_util from './axis_util'; export function gradForMinAndMax( dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) { if (y.rank < xOrig.rank) { - y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T; + y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)); } if (dy.rank < xOrig.rank) { - dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)) as T; + dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)); } return { x: () => { diff --git a/tfjs-core/src/ops/reshape.ts b/tfjs-core/src/ops/reshape.ts index f6db7332b21..9b4a6667283 100644 --- a/tfjs-core/src/ops/reshape.ts +++ b/tfjs-core/src/ops/reshape.ts @@ -52,17 +52,17 @@ import {op} from './operation'; * @param shape An array of integers defining the output tensor shape. */ /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ -function reshape_( - x: Tensor|TensorLike, shape: ShapeMap[R2]): Tensor { +function reshape_( + x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor { const $x = convertToTensor(x, 'x', 'reshape', null); - shape = util.inferFromImplicitShape(shape, $x.size) as ShapeMap[R2]; + shape = util.inferFromImplicitShape(shape, $x.size) as ShapeMap[R]; util.assert( $x.size === util.sizeFromShape(shape), () => 'new shape and old shape must have the same number of elements.'); const inputs: ReshapeInputs = {tensor: $x}; const attrs: ReshapeAttrs = {shape}; - const forward: ForwardFunc> = + const forward: ForwardFunc> = (backend: KernelBackend, save: GradSaveFunc) => { save([$x]); return backend.reshape($x, shape); diff --git a/tfjs-core/src/ops/segment_ops.ts b/tfjs-core/src/ops/segment_ops.ts index f2f75bef89e..357579ae6bb 100644 --- a/tfjs-core/src/ops/segment_ops.ts +++ b/tfjs-core/src/ops/segment_ops.ts @@ -137,7 +137,7 @@ function gather_( return res; }, {x: $x, indices: $indices}, grad, 'Gather', {axis})) - .reshape(shapeInfo.outputShape) as T; + .reshape(shapeInfo.outputShape); } function arrayRange(start: number, stop: number): number[] { diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 0b9ae654b47..d9b725d28b9 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -64,6 +64,7 @@ import './pow'; import './prelu'; import './prod'; import './relu'; +import './reshape'; import './resize_bilinear'; import './resize_nearest_neighbor'; import './relu6'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 2349e72b026..055aecabc36 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -74,6 +74,7 @@ const CHAINED_OPS = [ 'prelu', 'prod', 'relu', + 'reshape', 'resizeBilinear', 'resizeNearestNeighbor', 'relu6', diff --git a/tfjs-core/src/public/chained_ops/reshape.ts b/tfjs-core/src/public/chained_ops/reshape.ts new file mode 100644 index 00000000000..6d2baa471b5 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/reshape.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {reshape} from '../../ops/reshape'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + reshape(shape: number[]): T; + } +} + +Tensor.prototype.reshape = function(shape: number[]): T { + this.throwIfDisposed(); + return reshape(this, shape) as T; +}; diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index ade0667f8ac..72fcb998675 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -174,7 +174,6 @@ export interface OpHandler { shape: ShapeMap[R], dtype: D, values?: DataTypeMap[D]): TensorBuffer; print(x: T, verbose: boolean): void; - reshape(x: Tensor, shape: ShapeMap[R2]): Tensor; clone(x: T): T; gather(x: T, indices: Tensor|TensorLike, axis: number): T; norm( @@ -364,14 +363,14 @@ export class Tensor { asScalar(): Scalar { this.throwIfDisposed(); util.assert(this.size === 1, () => 'The array must have only 1 element.'); - return this.reshape([]); + return this.reshape([]); } /** Converts a `tf.Tensor` to a `tf.Tensor1D`. */ /** @doc {heading: 'Tensors', subheading: 'Classes'} */ as1D(): Tensor1D { this.throwIfDisposed(); - return this.reshape([this.size]); + return this.reshape([this.size]); } /** @@ -383,7 +382,7 @@ export class Tensor { /** @doc {heading: 'Tensors', subheading: 'Classes'} */ as2D(rows: number, columns: number): Tensor2D { this.throwIfDisposed(); - return this.reshape([rows, columns]); + return this.reshape([rows, columns]); } /** @@ -396,7 +395,7 @@ export class Tensor { /** @doc {heading: 'Tensors', subheading: 'Classes'} */ as3D(rows: number, columns: number, depth: number): Tensor3D { this.throwIfDisposed(); - return this.reshape([rows, columns, depth]); + return this.reshape([rows, columns, depth]); } /** @@ -410,7 +409,7 @@ export class Tensor { /** @doc {heading: 'Tensors', subheading: 'Classes'} */ as4D(rows: number, columns: number, depth: number, depth2: number): Tensor4D { this.throwIfDisposed(); - return this.reshape([rows, columns, depth, depth2]); + return this.reshape([rows, columns, depth, depth2]); } /** @@ -427,7 +426,7 @@ export class Tensor { rows: number, columns: number, depth: number, depth2: number, depth3: number): Tensor5D { this.throwIfDisposed(); - return this.reshape([rows, columns, depth, depth2, depth3]); + return this.reshape([rows, columns, depth, depth2, depth3]); } /** @@ -584,18 +583,6 @@ export class Tensor { return opHandler.print(this, verbose); } - /** - * Reshapes the tensor into the provided shape. - * See `tf.reshape` for more details. - * - * @param newShape An array of integers defining the output tensor shape. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - reshape(newShape: ShapeMap[R2]): Tensor { - this.throwIfDisposed(); - return opHandler.reshape(this, newShape); - } - /** * Reshapes the tensor into the shape of the provided tensor. * @@ -604,7 +591,7 @@ export class Tensor { /** @doc {heading: 'Tensors', subheading: 'Classes'} */ reshapeAs(x: T): T { this.throwIfDisposed(); - return this.reshape(x.shape) as T; + return this.reshape(x.shape); } /** Returns a copy of the tensor. See `tf.clone` for details. */ From b1fa9a456261decc3493b6b4943574b5e9611f2a Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 25 Jun 2020 15:40:45 -0400 Subject: [PATCH 2/5] reshape --- tfjs-backend-cpu/src/backend_cpu.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index 6ee8e5e6e1a..65d996083e5 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -18,7 +18,7 @@ import * as tf from '@tensorflow/tfjs-core'; import {engine, env} from '@tensorflow/tfjs-core'; import {backend_util, buffer, slice_util, util} from '@tensorflow/tfjs-core'; -import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, max, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core'; +import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, max, NumericDataType, Rank, reshape, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core'; import {kernel_impls} from '@tensorflow/tfjs-core'; const nonMaxSuppressionV3Impl = kernel_impls.nonMaxSuppressionV3Impl; @@ -2145,10 +2145,10 @@ export class MathBackendCPU extends KernelBackend { const flattenShape = backend_util.getReshapedPermuted( paddedX.shape, blockShape, prod, false); - return tf.transpose( - paddedX.reshape(reshapedPaddedShape), - permutedReshapedPaddedPermutation) - .reshape(flattenShape) as T; + const paddedXT = tf.transpose( + paddedX.reshape(reshapedPaddedShape), + permutedReshapedPaddedPermutation); + return reshape(paddedXT, flattenShape) as T; } maxPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { From 25a8319dd891786115f3787b0a0448555393c6a4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 25 Jun 2020 16:03:47 -0400 Subject: [PATCH 3/5] lint --- tfjs-backend-webgl/src/backend_webgl.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index c9c6e62a566..23a1b4c30fd 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -19,7 +19,7 @@ import './flags_webgl'; import * as tf from '@tensorflow/tfjs-core'; -import {complex, DataId, div, engine, env, imag, max, MemoryInfo, range, real, RecursiveArray, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core'; +import {complex, DataId, div, engine, env, imag, max, MemoryInfo, range, real, RecursiveArray, reshape, scalar, softmax, tensor, tidy, TimingInfo, transpose} from '@tensorflow/tfjs-core'; import {backend_util, buffer, kernel_impls, slice_util, util} from '@tensorflow/tfjs-core'; import {DataStorage, DataType, KernelBackend, NumericDataType, Rank, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorInfo, TypedArray, upcastType} from '@tensorflow/tfjs-core'; @@ -1020,10 +1020,10 @@ export class MathBackendWebGL extends KernelBackend { const flattenShape = backend_util.getReshapedPermuted( paddedX.shape, blockShape, prod, false); - return transpose( - paddedX.reshape(reshapedPaddedShape), - permutedReshapedPaddedPermutation) - .reshape(flattenShape) as T; + const paddedXT = transpose( + paddedX.reshape(reshapedPaddedShape), + permutedReshapedPaddedPermutation); + return reshape(paddedXT, flattenShape) as T; } private reduce( From 8c153c06a14247595c18ad1cc74c88bcc8463443 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 25 Jun 2020 16:43:46 -0400 Subject: [PATCH 4/5] renamegs --- tfjs-backend-wasm/src/kernels/Reshape.ts | 4 ++-- tfjs-core/src/kernel_names.ts | 2 +- tfjs-core/src/ops/reshape.ts | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Reshape.ts b/tfjs-backend-wasm/src/kernels/Reshape.ts index 93c9c2c0993..41416faf0d1 100644 --- a/tfjs-backend-wasm/src/kernels/Reshape.ts +++ b/tfjs-backend-wasm/src/kernels/Reshape.ts @@ -25,9 +25,9 @@ export function reshape(args: { backend: BackendWasm }) { const {inputs, attrs} = args; - const {tensor} = inputs as {} as ReshapeInputs; + const {x} = inputs as {} as ReshapeInputs; const {shape} = attrs as {} as ReshapeAttrs; - return {dataId: tensor.dataId, shape, dtype: tensor.dtype}; + return {dataId: x.dataId, shape, dtype: x.dtype}; } registerKernel({ diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index be432b6d05b..b705fdb228f 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -438,7 +438,7 @@ export const Relu = 'Relu'; export type ReluInputs = Pick; export const Reshape = 'Reshape'; -export type ReshapeInputs = Pick; +export type ReshapeInputs = Pick; export interface ReshapeAttrs { shape: number[]; } diff --git a/tfjs-core/src/ops/reshape.ts b/tfjs-core/src/ops/reshape.ts index 9b4a6667283..fe0a5b246cd 100644 --- a/tfjs-core/src/ops/reshape.ts +++ b/tfjs-core/src/ops/reshape.ts @@ -60,7 +60,7 @@ function reshape_( $x.size === util.sizeFromShape(shape), () => 'new shape and old shape must have the same number of elements.'); - const inputs: ReshapeInputs = {tensor: $x}; + const inputs: ReshapeInputs = {x: $x}; const attrs: ReshapeAttrs = {shape}; const forward: ForwardFunc> = (backend: KernelBackend, save: GradSaveFunc) => { From 9f4b5e2d23569628c7026aee2d6da0df991357d1 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 26 Jun 2020 07:33:24 -0400 Subject: [PATCH 5/5] rename --- tfjs-core/src/gradients/Reshape_grad.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/gradients/Reshape_grad.ts b/tfjs-core/src/gradients/Reshape_grad.ts index dc0ddd93637..e6ddb455164 100644 --- a/tfjs-core/src/gradients/Reshape_grad.ts +++ b/tfjs-core/src/gradients/Reshape_grad.ts @@ -21,9 +21,9 @@ import {Tensor} from '../tensor'; export const reshapeGradConfig: GradConfig = { kernelName: Reshape, - inputsToSave: ['tensor'], + inputsToSave: ['x'], gradFunc: (dy: Tensor, saved: Tensor[]) => { - const [tensor] = saved; - return {tensor: () => reshape(dy, tensor.shape)}; + const [x] = saved; + return {x: () => reshape(dy, x.shape)}; } };