diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index 6880be3269b..42d88224ce3 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -35,6 +35,7 @@ import {buffer, scalar, tensor, tensor4d} from '../../ops/ops'; import * as scatter_nd_util from '../../ops/scatter_nd_util'; import * as selu_util from '../../ops/selu_util'; import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util'; +import {transpose} from '../../ops/transpose'; import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; import {BackendValues, DataType, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; @@ -2111,32 +2112,6 @@ export class MathBackendCPU extends KernelBackend { return buffer.toTensor() as T; } - transpose(x: T, perm: number[]): T { - assertNotComplex(x, 'transpose'); - - const newShape: number[] = new Array(x.rank); - for (let i = 0; i < newShape.length; i++) { - newShape[i] = x.shape[perm[i]]; - } - const values = this.readSync(x.dataId) as TypedArray; - const result = buffer(newShape, x.dtype); - - const xBuf = this.bufferSync(x); - for (let i = 0; i < x.size; ++i) { - const loc = xBuf.indexToLoc(i); - - // Permute location. - const newLoc: number[] = new Array(loc.length); - for (let i = 0; i < newLoc.length; i++) { - newLoc[i] = loc[perm[i]]; - } - - const newIndex = result.locToIndex(newLoc); - result.values[newIndex] = values[i]; - } - return result.toTensor() as T; - } - gather(x: T, indices: Tensor1D, axis: number): T { assertNotComplex([x, indices], 'gather'); @@ -2174,8 +2149,7 @@ export class MathBackendCPU extends KernelBackend { const sliceSize = array_ops_util.getSliceSize(reshapedPermuted, crops, blockShape.length); - return x.reshape(reshaped) - .transpose(permuted) + return transpose(x.reshape(reshaped), permuted) .reshape(reshapedPermuted) .slice(sliceBeginCoords, sliceSize) as T; } @@ -2201,8 +2175,9 @@ export class MathBackendCPU extends KernelBackend { const flattenShape = array_ops_util.getReshapedPermuted( paddedX.shape, blockShape, prod, false); - return paddedX.reshape(reshapedPaddedShape) - .transpose(permutedReshapedPaddedPermutation) + return transpose( + paddedX.reshape(reshapedPaddedShape), + permutedReshapedPaddedPermutation) .reshape(flattenShape) as T; } diff --git a/tfjs-core/src/backends/cpu/kernels/Transpose.ts b/tfjs-core/src/backends/cpu/kernels/Transpose.ts new file mode 100644 index 00000000000..cfd905c4641 --- /dev/null +++ b/tfjs-core/src/backends/cpu/kernels/Transpose.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Transpose, TransposeAttrs, TransposeInputs} from '../../../kernel_names'; +import {KernelConfig} from '../../../kernel_registry'; +import {TypedArray} from '../../../types'; +import {MathBackendCPU} from '../backend_cpu'; +import {assertNotComplex} from '../cpu_util'; + +import {transposeImpl} from './Transpose_impl'; + +export const transposeConfig: KernelConfig = { + kernelName: Transpose, + backendName: 'cpu', + kernelFunc: ({inputs, attrs, backend}) => { + const {x} = inputs as TransposeInputs; + const {perm} = attrs as {} as TransposeAttrs; + const cpuBackend = backend as MathBackendCPU; + + assertNotComplex(x, 'transpose'); + + const xRank = x.shape.length; + + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[perm[i]]; + } + + const values = cpuBackend.data.get(x.dataId).values as TypedArray; + const result = transposeImpl(values, x.shape, x.dtype, perm, newShape); + + const dataId = cpuBackend.write(result, newShape, x.dtype); + return {dataId, shape: newShape, dtype: x.dtype}; + } +}; diff --git a/tfjs-core/src/backends/cpu/kernels/Transpose_impl.ts b/tfjs-core/src/backends/cpu/kernels/Transpose_impl.ts new file mode 100644 index 00000000000..80ccf480f16 --- /dev/null +++ b/tfjs-core/src/backends/cpu/kernels/Transpose_impl.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {DataType, NumericDataType, TypedArray} from '../../../types'; +import * as util from '../../../util'; + +export function transposeImpl( + xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[], + newShape: number[]): TypedArray { + const xSize = util.sizeFromShape(xShape); + const xRank = xShape.length; + const xStrides = util.computeStrides(xShape); + const newStrides = util.computeStrides(newShape); + + const result = util.getTypedArrayFromDType( + dtype as NumericDataType, util.sizeFromShape(newShape)); + + for (let i = 0; i < xSize; ++i) { + const loc = util.indexToLoc(i, xRank, xStrides); + + // Permute location. + const newLoc: number[] = new Array(loc.length); + for (let i = 0; i < newLoc.length; i++) { + newLoc[i] = loc[perm[i]]; + } + + const newIndex = util.locToIndex(newLoc, xRank, newStrides); + result[newIndex] = xVals[i]; + } + return result; +} diff --git a/tfjs-core/src/backends/cpu/register_all_kernels.ts b/tfjs-core/src/backends/cpu/register_all_kernels.ts index 4d7f2e983cf..0a1055fd93b 100644 --- a/tfjs-core/src/backends/cpu/register_all_kernels.ts +++ b/tfjs-core/src/backends/cpu/register_all_kernels.ts @@ -23,10 +23,12 @@ import {divConfig} from './kernels/Div'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; +import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig + nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig, + transposeConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 64d29147b21..ccdc27d8e80 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -39,6 +39,7 @@ import * as segment_util from '../../ops/segment_util'; import * as slice_util from '../../ops/slice_util'; import {softmax} from '../../ops/softmax'; import {range, scalar, tensor} from '../../ops/tensor_ops'; +import {transpose} from '../../ops/transpose'; import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor'; import {BackendValues, DataType, DataTypeMap, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; @@ -125,8 +126,6 @@ import * as tex_util from './tex_util'; import {TextureData, TextureUsage} from './tex_util'; import {TextureManager} from './texture_manager'; import {TileProgram} from './tile_gpu'; -import {TransposeProgram} from './transpose_gpu'; -import {TransposePackedProgram} from './transpose_packed_gpu'; import * as unary_op from './unaryop_gpu'; import {UnaryOpProgram} from './unaryop_gpu'; import * as unary_packed_op from './unaryop_packed_gpu'; @@ -653,12 +652,13 @@ export class MathBackendWebGL extends KernelBackend { TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more sustainable strategy for optimizing backend execution of ops. */ - private shouldExecuteOnCPU( - inputs: Tensor[], sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean { + shouldExecuteOnCPU( + inputs: TensorInfo[], + sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean { return this.getCPUBackend() != null && inputs.every( input => this.texData.get(input.dataId).texture == null && - input.size < sizeThreshold); + util.sizeFromShape(input.shape) < sizeThreshold); } getGPGPUContext(): GPGPUContext { @@ -821,10 +821,10 @@ export class MathBackendWebGL extends KernelBackend { if ((outerShapeA === 1 || outerShapeB === 1) && sharedDim > MATMUL_SHARED_DIM_THRESHOLD) { if (transposeA) { - a = a.transpose([0, 2, 1]); + a = transpose(a, [0, 2, 1]); } if (transposeB) { - b = b.transpose([0, 2, 1]); + b = transpose(b, [0, 2, 1]); } const a3D = outerShapeB === 1 ? a : a.as3D(batch, sharedDim, 1); @@ -969,16 +969,6 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [x]); } - transpose(x: T, perm: number[]): T { - if (this.shouldExecuteOnCPU([x])) { - return this.cpuBackend.transpose(x, perm); - } - const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? - new TransposePackedProgram(x.shape, perm) : - new TransposeProgram(x.shape, perm); - return this.compileAndRun(program, [x]); - } - gather(x: T, indices: Tensor1D, axis: number): T { if (this.shouldExecuteOnCPU([x, indices])) { return this.cpuBackend.gather(x, indices, axis); @@ -1005,8 +995,7 @@ export class MathBackendWebGL extends KernelBackend { const sliceSize = array_ops_util.getSliceSize(reshapedPermuted, crops, blockShape.length); - return x.reshape(reshaped) - .transpose(permuted) + return transpose(x.reshape(reshaped), permuted) .reshape(reshapedPermuted) .slice(sliceBeginCoords, sliceSize) as T; } @@ -1037,8 +1026,9 @@ export class MathBackendWebGL extends KernelBackend { const flattenShape = array_ops_util.getReshapedPermuted( paddedX.shape, blockShape, prod, false); - return paddedX.reshape(reshapedPaddedShape) - .transpose(permutedReshapedPaddedPermutation) + return transpose( + paddedX.reshape(reshapedPaddedShape), + permutedReshapedPaddedPermutation) .reshape(flattenShape) as T; } @@ -1127,7 +1117,7 @@ export class MathBackendWebGL extends KernelBackend { const permutation = axis_util.getAxesPermutation([axis], x.rank); let permutedX = x; if (permutation != null) { - permutedX = x.transpose(permutation); + permutedX = transpose(x, permutation); axis = axis_util.getInnerMostAxes(1, x.rank)[0]; } @@ -1141,7 +1131,7 @@ export class MathBackendWebGL extends KernelBackend { a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments) .reshape(outShape); if (permutation != null) { - result = result.transpose(axis_util.getUndoAxesPermutation(permutation)); + result = transpose(result, axis_util.getUndoAxesPermutation(permutation)); } return result; } diff --git a/tfjs-core/src/backends/webgl/kernels/Transpose.ts b/tfjs-core/src/backends/webgl/kernels/Transpose.ts new file mode 100644 index 00000000000..086923e5eaa --- /dev/null +++ b/tfjs-core/src/backends/webgl/kernels/Transpose.ts @@ -0,0 +1,54 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {TypedArray} from '../../../../src/types'; +import {transposeImpl as cpuTranspose} from '../../../backends/cpu/kernels/Transpose_impl'; +import {Transpose, TransposeAttrs, TransposeInputs} from '../../../kernel_names'; +import {KernelConfig, TensorInfo} from '../../../kernel_registry'; +import {MathBackendWebGL} from '../backend_webgl'; +import {transposeImpl} from './Transpose_impl'; + +export const transposeConfig: KernelConfig = { + kernelName: Transpose, + backendName: 'webgl', + kernelFunc: ({inputs, attrs, backend}) => { + const {x} = inputs as TransposeInputs; + const {perm} = attrs as {} as TransposeAttrs; + const webglBackend = backend as MathBackendWebGL; + + const xRank = x.shape.length; + + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[perm[i]]; + } + + let out: TensorInfo; + if (webglBackend.shouldExecuteOnCPU([x])) { + const xTexData = webglBackend.texData.get(x.dataId); + const values = xTexData.values as TypedArray; + const outValues = cpuTranspose(values, x.shape, x.dtype, perm, newShape); + + out = webglBackend.makeTensorInfo(newShape, x.dtype); + const outData = webglBackend.texData.get(out.dataId); + outData.values = outValues; + } else { + out = transposeImpl(x, perm, webglBackend); + } + return out; + } +}; diff --git a/tfjs-core/src/backends/webgl/kernels/Transpose_impl.ts b/tfjs-core/src/backends/webgl/kernels/Transpose_impl.ts new file mode 100644 index 00000000000..498aae12529 --- /dev/null +++ b/tfjs-core/src/backends/webgl/kernels/Transpose_impl.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {env} from '../../../environment'; +import {TensorInfo} from '../../../kernel_registry'; +import {MathBackendWebGL} from '../backend_webgl'; +import {TransposeProgram} from '../transpose_gpu'; +import {TransposePackedProgram} from '../transpose_packed_gpu'; + +export function transposeImpl( + x: TensorInfo, perm: number[], backend: MathBackendWebGL): TensorInfo { + const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + new TransposePackedProgram(x.shape, perm) : + new TransposeProgram(x.shape, perm); + return backend.runWebGLProgram(program, [x], x.dtype); +} diff --git a/tfjs-core/src/backends/webgl/register_all_kernels.ts b/tfjs-core/src/backends/webgl/register_all_kernels.ts index fe04734d2cb..0a5351cf813 100644 --- a/tfjs-core/src/backends/webgl/register_all_kernels.ts +++ b/tfjs-core/src/backends/webgl/register_all_kernels.ts @@ -21,6 +21,7 @@ import {fromPixelsConfig} from './kernels/FromPixels'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; +import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ @@ -29,6 +30,7 @@ const kernelConfigs: KernelConfig[] = [ nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, + transposeConfig, ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-core/src/backends/webgl/webgl_ops_test.ts b/tfjs-core/src/backends/webgl/webgl_ops_test.ts index 3dc5c1ebab1..93a29f004d5 100644 --- a/tfjs-core/src/backends/webgl/webgl_ops_test.ts +++ b/tfjs-core/src/backends/webgl/webgl_ops_test.ts @@ -438,8 +438,9 @@ describeWithFlags('gramSchmidt-non-tiny', WEBGL_ENVS, () => { // can complete in the timeout limit of the unit test. const xs: Tensor2D = tf.randomUniform([8, 16]); const y = tf.linalg.gramSchmidt(xs) as Tensor2D; + const yTransposed: Tensor2D = y.transpose(); expectArraysClose( - await y.matMul(y.transpose()).data(), await tf.eye(8).data()); + await y.matMul(yTransposed).data(), await tf.eye(8).data()); }); }); diff --git a/tfjs-core/src/gradients/Transpose_grad.ts b/tfjs-core/src/gradients/Transpose_grad.ts new file mode 100644 index 00000000000..4d48500d7de --- /dev/null +++ b/tfjs-core/src/gradients/Transpose_grad.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Transpose, TransposeAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import * as axis_util from '../ops/axis_util'; +import {transpose} from '../ops/transpose'; +import {Tensor} from '../tensor'; + +export const transposeGradConfig: GradConfig = { + kernelName: Transpose, + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const transposeAttrs: TransposeAttrs = attrs as {} as TransposeAttrs; + const {perm} = transposeAttrs; + const undoPerm = axis_util.getUndoAxesPermutation(perm); + return {x: () => transpose(dy, undoPerm)}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 8ecdb5bfe7e..239d9794a46 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -32,6 +32,12 @@ export type SquaredDifferenceInputs = BinaryInputs; export const Square = 'Square'; export type SquareInputs = Pick; +export const Transpose = 'Transpose'; +export type TransposeInputs = Pick; +export interface TransposeAttrs { + perm: number[]; +} + export const NonMaxSuppressionV5 = 'NonMaxSuppressionV5'; export type NonMaxSuppressionV5Inputs = Pick; diff --git a/tfjs-core/src/ops/confusion_matrix.ts b/tfjs-core/src/ops/confusion_matrix.ts index defb6cd3cc2..60d8c465c6c 100644 --- a/tfjs-core/src/ops/confusion_matrix.ts +++ b/tfjs-core/src/ops/confusion_matrix.ts @@ -84,7 +84,8 @@ export function confusionMatrix_( const oneHotLabels = oneHot($labels.asType('int32'), numClasses) as Tensor2D; const oneHotPredictions = oneHot($predictions.asType('int32'), numClasses) as Tensor2D; - return oneHotLabels.transpose().matMul(oneHotPredictions).asType('int32'); + const oneHotLabelsT: Tensor2D = oneHotLabels.transpose(); + return oneHotLabelsT.matMul(oneHotPredictions).asType('int32'); } export const confusionMatrix = op({confusionMatrix_}); diff --git a/tfjs-core/src/ops/linalg_ops.ts b/tfjs-core/src/ops/linalg_ops.ts index c8af437a70c..cdf721f4ffc 100644 --- a/tfjs-core/src/ops/linalg_ops.ts +++ b/tfjs-core/src/ops/linalg_ops.ts @@ -331,19 +331,21 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { // -- R := HR, Q := QH. const rjEndAll = r.slice([j, 0], [m - j, n]); const tauTimesW: Tensor2D = tau.mul(w); + const wT: Tensor2D = w.transpose(); if (j === 0) { - r = rjEndAll.sub(tauTimesW.matMul(w.transpose().matMul(rjEndAll))); + r = rjEndAll.sub(tauTimesW.matMul(wT.matMul(rjEndAll))); } else { const rTimesTau: Tensor2D = - rjEndAll.sub(tauTimesW.matMul(w.transpose().matMul(rjEndAll))); + rjEndAll.sub(tauTimesW.matMul(wT.matMul(rjEndAll))); r = r.slice([0, 0], [j, n]).concat(rTimesTau, 0); } + const tawTimesWT: Tensor2D = tauTimesW.transpose(); const qAllJEnd = q.slice([0, j], [m, q.shape[1] - j]); if (j === 0) { - q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tauTimesW.transpose())); + q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tawTimesWT)); } else { const qTimesTau: Tensor2D = - qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tauTimesW.transpose())); + qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tawTimesWT)); q = q.slice([0, 0], [m, j]).concat(qTimesTau, 1); } return [w, r, q]; diff --git a/tfjs-core/src/ops/linalg_ops_test.ts b/tfjs-core/src/ops/linalg_ops_test.ts index c2baa84b679..bb5d76a413f 100644 --- a/tfjs-core/src/ops/linalg_ops_test.ts +++ b/tfjs-core/src/ops/linalg_ops_test.ts @@ -95,127 +95,90 @@ describeWithFlags('bandPart', ALL_ENVS, () => { it('fails for scalar', async () => { const x = scalar(1); - expect( () => la.bandPart(x, 1, 2) ).toThrowError(/bandPart.*rank/i); + expect(() => la.bandPart(x, 1, 2)).toThrowError(/bandPart.*rank/i); }); it('fails for 1D tensor', async () => { const x = tensor1d([1, 2, 3, 4, 5]); - expect( () => la.bandPart(x, 1, 2) ).toThrowError(/bandPart.*rank/i); + expect(() => la.bandPart(x, 1, 2)).toThrowError(/bandPart.*rank/i); }); it('fails if numLower or numUpper too large', async () => { - const a = tf.tensor2d([[1, 2, 3], - [4, 5, 6]]); + const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); - for( const numLower of [ 3,5,8,13] ) { - for( const numUpper of [-1,0,1, 2] ) { - expect( () => tf.linalg.bandPart(a, numLower, numUpper) ) - .toThrowError(/bandPart.*numLower/i); - }} + for (const numLower of [3, 5, 8, 13]) { + for (const numUpper of [-1, 0, 1, 2]) { + expect(() => tf.linalg.bandPart(a, numLower, numUpper)) + .toThrowError(/bandPart.*numLower/i); + } + } - for( const numLower of [-1,0,1] ) { - for( const numUpper of [ 4,5,9] ) { - expect( () => tf.linalg.bandPart(a, numLower, numUpper) ) - .toThrowError(/bandPart.*numUpper/i); - }} + for (const numLower of [-1, 0, 1]) { + for (const numUpper of [4, 5, 9]) { + expect(() => tf.linalg.bandPart(a, numLower, numUpper)) + .toThrowError(/bandPart.*numUpper/i); + } + } - for( const numLower of [ 3,5,8,13] ) { - for( const numUpper of [ 4,5, 9] ) { - expect( () => tf.linalg.bandPart(a, numLower, numUpper) ) - .toThrowError(/bandPart.*(numLower|numUpper)/i); - }} + for (const numLower of [3, 5, 8, 13]) { + for (const numUpper of [4, 5, 9]) { + expect(() => tf.linalg.bandPart(a, numLower, numUpper)) + .toThrowError(/bandPart.*(numLower|numUpper)/i); + } + } }); it('works for 3x4 example', async () => { - const a = tf.tensor2d([[1, 2, 3, 4], - [5, 6, 7, 8], - [9,10,11,12]]); + const a = tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]); expectArraysClose( - await la.bandPart(a,0,0).array(), - [[1, 0, 0, 0], - [0, 6, 0, 0], - [0, 0,11, 0]] - ); + await la.bandPart(a, 0, 0).array(), + [[1, 0, 0, 0], [0, 6, 0, 0], [0, 0, 11, 0]]); expectArraysClose( - await la.bandPart(a,0,1).array(), - [[1, 2, 0, 0], - [0, 6, 7, 0], - [0, 0,11,12]] - ); + await la.bandPart(a, 0, 1).array(), + [[1, 2, 0, 0], [0, 6, 7, 0], [0, 0, 11, 12]]); expectArraysClose( - await la.bandPart(a,0,2).array(), - [[1, 2, 3, 0], - [0, 6, 7, 8], - [0, 0,11,12]] - ); - for( const numUpper of [3,4,-1,-2] ) { + await la.bandPart(a, 0, 2).array(), + [[1, 2, 3, 0], [0, 6, 7, 8], [0, 0, 11, 12]]); + for (const numUpper of [3, 4, -1, -2]) { expectArraysClose( - await la.bandPart(a,0,numUpper).array(), - [[1, 2, 3, 4], - [0, 6, 7, 8], - [0, 0,11,12]] - ); + await la.bandPart(a, 0, numUpper).array(), + [[1, 2, 3, 4], [0, 6, 7, 8], [0, 0, 11, 12]]); } expectArraysClose( - await la.bandPart(a,1,0).array(), - [[1, 0, 0, 0], - [5, 6, 0, 0], - [0,10,11, 0]] - ); + await la.bandPart(a, 1, 0).array(), + [[1, 0, 0, 0], [5, 6, 0, 0], [0, 10, 11, 0]]); expectArraysClose( - await la.bandPart(a,1,1).array(), - [[1, 2, 0, 0], - [5, 6, 7, 0], - [0,10,11,12]] - ); + await la.bandPart(a, 1, 1).array(), + [[1, 2, 0, 0], [5, 6, 7, 0], [0, 10, 11, 12]]); expectArraysClose( - await la.bandPart(a,1,2).array(), - [[1, 2, 3, 0], - [5, 6, 7, 8], - [0,10,11,12]] - ); - for( const numUpper of [3,4,-1,-2] ) { + await la.bandPart(a, 1, 2).array(), + [[1, 2, 3, 0], [5, 6, 7, 8], [0, 10, 11, 12]]); + for (const numUpper of [3, 4, -1, -2]) { expectArraysClose( - await la.bandPart(a,1,numUpper).array(), - [[1, 2, 3, 4], - [5, 6, 7, 8], - [0,10,11,12]] - ); + await la.bandPart(a, 1, numUpper).array(), + [[1, 2, 3, 4], [5, 6, 7, 8], [0, 10, 11, 12]]); } - for( const numLower of [2,3,-1,-2]) - { + for (const numLower of [2, 3, -1, -2]) { expectArraysClose( - await la.bandPart(a,numLower,0).array(), - [[1, 0, 0, 0], - [5, 6, 0, 0], - [9,10,11, 0]] - ); + await la.bandPart(a, numLower, 0).array(), + [[1, 0, 0, 0], [5, 6, 0, 0], [9, 10, 11, 0]]); expectArraysClose( - await la.bandPart(a,numLower,1).array(), - [[1, 2, 0, 0], - [5, 6, 7, 0], - [9,10,11,12]] - ); + await la.bandPart(a, numLower, 1).array(), + [[1, 2, 0, 0], [5, 6, 7, 0], [9, 10, 11, 12]]); expectArraysClose( - await la.bandPart(a,numLower,2).array(), - [[1, 2, 3, 0], - [5, 6, 7, 8], - [9,10,11,12]] - ); - for( const numUpper of [3,4,-1,-2] ) { + await la.bandPart(a, numLower, 2).array(), + [[1, 2, 3, 0], [5, 6, 7, 8], [9, 10, 11, 12]]); + for (const numUpper of [3, 4, -1, -2]) { expectArraysClose( - await la.bandPart(a,numLower,numUpper).array(), - [[1, 2, 3, 4], - [5, 6, 7, 8], - [9,10,11,12]] - ); + await la.bandPart(a, numLower, numUpper).array(), + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]); } } }); -}); // end bandPart +}); // end bandPart describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => { it('2x2, Array of Tensor1D', async () => { @@ -260,8 +223,8 @@ describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => { it('2x3, Matrix', async () => { const xs: Tensor2D = tf.randomNormal([2, 3], 0, 1, 'float32', 1); const y = tf.linalg.gramSchmidt(xs) as Tensor2D; - expectArraysClose( - await y.matMul(y.transpose()).array(), await tf.eye(2).array()); + const yT: Tensor2D = y.transpose(); + expectArraysClose(await y.matMul(yT).array(), await tf.eye(2).array()); }); it('3x2 Matrix throws Error', () => { diff --git a/tfjs-core/src/ops/transpose.ts b/tfjs-core/src/ops/transpose.ts index 0f9a178ea89..e6fc8bda6b4 100644 --- a/tfjs-core/src/ops/transpose.ts +++ b/tfjs-core/src/ops/transpose.ts @@ -20,7 +20,6 @@ import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import * as axis_util from './axis_util'; import {op} from './operation'; /** @@ -62,13 +61,10 @@ function transpose_(x: T|TensorLike, perm?: number[]): T { return $x.clone(); } - const der = (dy: T) => { - const undoPerm = axis_util.getUndoAxesPermutation(perm); - return {x: () => dy.transpose(undoPerm)}; - }; const attrs = {perm}; return ENGINE.runKernelFunc( - backend => backend.transpose($x, perm), {x: $x}, der, 'Transpose', attrs); + backend => backend.transpose($x, perm), {x: $x}, null /* gradient */, + 'Transpose', attrs); } export const transpose = op({transpose_}); diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index d3cc5fa4b3d..d398ba1944d 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -21,4 +21,5 @@ import './div_no_nan'; import './squared_difference'; import './tile'; import './one_hot'; +import './transpose'; import './pad'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index aa11375f069..92aa87067c0 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -23,8 +23,10 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; // (And kerma will always load the chain augmentor files). But this gives us // flexibility to change in future. -const CHAINED_OPS = - ['square', 'broadcastTo', 'tile', 'oneHot', 'div', 'divNoNan', 'pad']; +const CHAINED_OPS = [ + 'square', 'broadcastTo', 'tile', 'oneHot', 'div', 'divNoNan', 'transpose', + 'pad' +]; describeWithFlags('chained ops', ALL_ENVS, () => { it('all chained ops should exist on tensor ', async () => { diff --git a/tfjs-core/src/public/chained_ops/transpose.ts b/tfjs-core/src/public/chained_ops/transpose.ts new file mode 100644 index 00000000000..fc0bc73336c --- /dev/null +++ b/tfjs-core/src/public/chained_ops/transpose.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {transpose} from '../../ops/transpose'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + transpose(perm?: number[]): T; + } +} + +Tensor.prototype.transpose = function( + this: T, perm?: number[]): T { + return transpose(this, perm); +}; diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 193f234eb8b..a35522b2998 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -22,6 +22,7 @@ import {padV2GradConfig} from './gradients/PadV2_grad'; import {squareGradConfig} from './gradients/Square_grad'; import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; import {tileGradConfig} from './gradients/Tile_grad'; +import {transposeGradConfig} from './gradients/Transpose_grad'; import {GradConfig} from './kernel_registry'; import {registerGradient} from './kernel_registry'; @@ -29,7 +30,7 @@ import {registerGradient} from './kernel_registry'; const gradConfigs: GradConfig[] = [ divGradConfig, squareGradConfig, squaredDifferenceGradConfig, broadcastToGradConfig, identityGradConfig, tileGradConfig, oneHotGradConfig, - padV2GradConfig + transposeGradConfig, padV2GradConfig ]; for (const gradientConfig of gradConfigs) { diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 848cc3a1711..c7036e2f3a5 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -232,7 +232,6 @@ export interface OpHandler { maximum(a: Tensor, b: Tensor|TensorLike): T; maximumStrict(a: T, b: T|TensorLike): T; squaredDifferenceStrict(a: T, b: T|TensorLike): T; - transpose(x: T, perm?: number[]): T; logicalNot(x: T): T; logicalAnd(a: Tensor, b: Tensor|TensorLike): T; logicalOr(a: Tensor, b: Tensor|TensorLike): T; @@ -968,10 +967,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.squaredDifferenceStrict(this, x); } - transpose(this: T, perm?: number[]): T { - this.throwIfDisposed(); - return opHandler.transpose(this, perm); - } // Compare ops. diff --git a/tfjs-layers/src/initializers_test.ts b/tfjs-layers/src/initializers_test.ts index 371998adb29..e28ef14a8d2 100644 --- a/tfjs-layers/src/initializers_test.ts +++ b/tfjs-layers/src/initializers_test.ts @@ -602,8 +602,9 @@ describeMathCPUAndGPU('Orthogonal Initializer', () => { const w = init.apply([2, 4], 'float32') as Tensor2D; expect(w.shape).toEqual([2, 4]); expect(w.dtype).toEqual('float32'); + const wT: Tensor2D = w.transpose(); // Assert that columns of w are orthogonal. - expectTensorsClose(w.matMul(w.transpose()), eye(2)); + expectTensorsClose(w.matMul(wT), eye(2)); }); it('64x64', () => { @@ -616,7 +617,8 @@ describeMathCPUAndGPU('Orthogonal Initializer', () => { expect(w.shape).toEqual([n, n]); expect(w.dtype).toEqual('float32'); // Assert that columns of w are orthogonal. - expectTensorsClose(w.matMul(w.transpose()), eye(n)); + const wT: Tensor2D = w.transpose(); + expectTensorsClose(w.matMul(wT), eye(n)); }); it('Does not leak', () => { const init = getInitializer('Orthogonal');