Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 5 additions & 30 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {buffer, scalar, tensor, tensor4d} from '../../ops/ops';
import * as scatter_nd_util from '../../ops/scatter_nd_util';
import * as selu_util from '../../ops/selu_util';
import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util';
import {transpose} from '../../ops/transpose';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
import {BackendValues, DataType, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
Expand Down Expand Up @@ -2111,32 +2112,6 @@ export class MathBackendCPU extends KernelBackend {
return buffer.toTensor() as T;
}

/**
 * Produces a new tensor holding the values of `x` with its axes rearranged:
 * axis `k` of the result takes its extent (and coordinates) from axis
 * `perm[k]` of `x`. Complex inputs are rejected via `assertNotComplex`.
 */
transpose<T extends Tensor>(x: T, perm: number[]): T {
  assertNotComplex(x, 'transpose');

  // Builds a `count`-long array whose slot k holds source[perm[k]]. Used for
  // both the output shape and each per-element coordinate.
  const permute = (source: number[], count: number): number[] => {
    const picked: number[] = new Array(count);
    for (let k = 0; k < count; k++) {
      picked[k] = source[perm[k]];
    }
    return picked;
  };

  const outShape = permute(x.shape, x.rank);
  const srcValues = this.readSync(x.dataId) as TypedArray;
  const out = buffer(outShape, x.dtype);
  const srcBuf = this.bufferSync(x);

  for (let srcIndex = 0; srcIndex < x.size; ++srcIndex) {
    // Coordinate of this element in the input, then its permuted counterpart
    // in the output.
    const srcLoc = srcBuf.indexToLoc(srcIndex);
    const dstLoc = permute(srcLoc, srcLoc.length);
    out.values[out.locToIndex(dstLoc)] = srcValues[srcIndex];
  }
  return out.toTensor() as T;
}

gather<T extends Tensor>(x: T, indices: Tensor1D, axis: number): T {
assertNotComplex([x, indices], 'gather');

Expand Down Expand Up @@ -2174,8 +2149,7 @@ export class MathBackendCPU extends KernelBackend {
const sliceSize =
array_ops_util.getSliceSize(reshapedPermuted, crops, blockShape.length);

return x.reshape(reshaped)
.transpose(permuted)
return transpose(x.reshape(reshaped), permuted)
.reshape(reshapedPermuted)
.slice(sliceBeginCoords, sliceSize) as T;
}
Expand All @@ -2201,8 +2175,9 @@ export class MathBackendCPU extends KernelBackend {
const flattenShape = array_ops_util.getReshapedPermuted(
paddedX.shape, blockShape, prod, false);

return paddedX.reshape(reshapedPaddedShape)
.transpose(permutedReshapedPaddedPermutation)
return transpose(
paddedX.reshape(reshapedPaddedShape),
permutedReshapedPaddedPermutation)
.reshape(flattenShape) as T;
}

Expand Down
49 changes: 49 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Transpose.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Transpose, TransposeAttrs, TransposeInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import {TypedArray} from '../../../types';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

import {transposeImpl} from './Transpose_impl';

/**
 * Kernel registration for `Transpose` on the 'cpu' backend.
 *
 * The kernel function rejects complex inputs, derives the output shape by
 * reordering the input shape with `perm`, delegates the element shuffling to
 * `transposeImpl`, and writes the resulting values back to the backend.
 */
export const transposeConfig: KernelConfig = {
  kernelName: Transpose,
  backendName: 'cpu',
  kernelFunc: ({inputs, attrs, backend}) => {
    const {x} = inputs as TransposeInputs;
    const {perm} = attrs as {} as TransposeAttrs;
    const cpuBackend = backend as MathBackendCPU;

    assertNotComplex(x, 'transpose');

    const {shape: inShape, dtype, dataId: inDataId} = x;

    // Output axis `axis` has the extent of input axis `perm[axis]`.
    const outShape: number[] = new Array(inShape.length);
    for (let axis = 0; axis < outShape.length; ++axis) {
      outShape[axis] = inShape[perm[axis]];
    }

    const inValues = cpuBackend.data.get(inDataId).values as TypedArray;
    const outValues = transposeImpl(inValues, inShape, dtype, perm, outShape);

    return {
      dataId: cpuBackend.write(outValues, outShape, dtype),
      shape: outShape,
      dtype
    };
  }
};
45 changes: 45 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Transpose_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {DataType, NumericDataType, TypedArray} from '../../../types';
import * as util from '../../../util';

/**
 * Rearranges the elements of a flat value buffer so that its axes follow
 * `perm`, and returns the rearranged values in a newly allocated typed array.
 *
 * @param xVals Flat values of the input, indexed by the strides that
 *     `util.computeStrides(xShape)` yields.
 * @param xShape Shape of the input.
 * @param dtype Data type used to pick the output typed array; it is passed to
 *     `util.getTypedArrayFromDType` as a `NumericDataType`.
 * @param perm Axis permutation: coordinate `k` of an output element is taken
 *     from coordinate `perm[k]` of the corresponding input element.
 * @param newShape Shape of the output. The kernels calling this function
 *     compute it as `newShape[k] = xShape[perm[k]]`.
 * @returns A typed array of `util.sizeFromShape(newShape)` elements holding
 *     the permuted values.
 */
export function transposeImpl(
    xVals: TypedArray, xShape: number[], dtype: DataType, perm: number[],
    newShape: number[]): TypedArray {
  const xSize = util.sizeFromShape(xShape);
  const xRank = xShape.length;
  const xStrides = util.computeStrides(xShape);
  const newStrides = util.computeStrides(newShape);

  const result = util.getTypedArrayFromDType(
      dtype as NumericDataType, util.sizeFromShape(newShape));

  // Scratch coordinate shared by all elements. Every slot is overwritten on
  // each iteration, so allocating it once avoids one array allocation per
  // element. NOTE(review): this relies on `util.indexToLoc` returning `xRank`
  // coordinates and on `util.locToIndex` not retaining its argument.
  const newLoc: number[] = new Array(xRank);

  for (let i = 0; i < xSize; ++i) {
    const loc = util.indexToLoc(i, xRank, xStrides);

    // Permute location. The axis index has its own name so that it does not
    // shadow the element index `i`.
    for (let axis = 0; axis < xRank; ++axis) {
      newLoc[axis] = loc[perm[axis]];
    }

    const newIndex = util.locToIndex(newLoc, xRank, newStrides);
    result[newIndex] = xVals[i];
  }
  return result;
}
4 changes: 3 additions & 1 deletion tfjs-core/src/backends/cpu/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ import {divConfig} from './kernels/Div';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {squareConfig} from './kernels/Square';
import {squaredDifferenceConfig} from './kernels/SquaredDifference';
import {transposeConfig} from './kernels/Transpose';

// List all kernel configs here
const kernelConfigs: KernelConfig[] = [
nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig
nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig,
transposeConfig
];

for (const kernelConfig of kernelConfigs) {
Expand Down
36 changes: 13 additions & 23 deletions tfjs-core/src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import * as segment_util from '../../ops/segment_util';
import * as slice_util from '../../ops/slice_util';
import {softmax} from '../../ops/softmax';
import {range, scalar, tensor} from '../../ops/tensor_ops';
import {transpose} from '../../ops/transpose';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
import {BackendValues, DataType, DataTypeMap, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
Expand Down Expand Up @@ -125,8 +126,6 @@ import * as tex_util from './tex_util';
import {TextureData, TextureUsage} from './tex_util';
import {TextureManager} from './texture_manager';
import {TileProgram} from './tile_gpu';
import {TransposeProgram} from './transpose_gpu';
import {TransposePackedProgram} from './transpose_packed_gpu';
import * as unary_op from './unaryop_gpu';
import {UnaryOpProgram} from './unaryop_gpu';
import * as unary_packed_op from './unaryop_packed_gpu';
Expand Down Expand Up @@ -653,12 +652,13 @@ export class MathBackendWebGL extends KernelBackend {
TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
sustainable strategy for optimizing backend execution of ops.
*/
private shouldExecuteOnCPU(
inputs: Tensor[], sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean {
shouldExecuteOnCPU(
inputs: TensorInfo[],
sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean {
return this.getCPUBackend() != null &&
inputs.every(
input => this.texData.get(input.dataId).texture == null &&
input.size < sizeThreshold);
util.sizeFromShape(input.shape) < sizeThreshold);
}

getGPGPUContext(): GPGPUContext {
Expand Down Expand Up @@ -821,10 +821,10 @@ export class MathBackendWebGL extends KernelBackend {
if ((outerShapeA === 1 || outerShapeB === 1) &&
sharedDim > MATMUL_SHARED_DIM_THRESHOLD) {
if (transposeA) {
a = a.transpose([0, 2, 1]);
a = transpose(a, [0, 2, 1]);
}
if (transposeB) {
b = b.transpose([0, 2, 1]);
b = transpose(b, [0, 2, 1]);
}

const a3D = outerShapeB === 1 ? a : a.as3D(batch, sharedDim, 1);
Expand Down Expand Up @@ -969,16 +969,6 @@ export class MathBackendWebGL extends KernelBackend {
return this.compileAndRun(program, [x]);
}

/**
 * Transposes `x` according to `perm`. When `shouldExecuteOnCPU` accepts the
 * input, the work is handed to the CPU backend; otherwise a WebGL program is
 * compiled and run, using the packed variant when the
 * 'WEBGL_PACK_ARRAY_OPERATIONS' flag is set.
 */
transpose<T extends Tensor>(x: T, perm: number[]): T {
  const handOffToCpu = this.shouldExecuteOnCPU([x]);
  if (handOffToCpu) {
    return this.cpuBackend.transpose(x, perm);
  }

  if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')) {
    return this.compileAndRun(new TransposePackedProgram(x.shape, perm), [x]);
  }
  return this.compileAndRun(new TransposeProgram(x.shape, perm), [x]);
}

gather<T extends Tensor>(x: T, indices: Tensor1D, axis: number): T {
if (this.shouldExecuteOnCPU([x, indices])) {
return this.cpuBackend.gather(x, indices, axis);
Expand All @@ -1005,8 +995,7 @@ export class MathBackendWebGL extends KernelBackend {
const sliceSize =
array_ops_util.getSliceSize(reshapedPermuted, crops, blockShape.length);

return x.reshape(reshaped)
.transpose(permuted)
return transpose(x.reshape(reshaped), permuted)
.reshape(reshapedPermuted)
.slice(sliceBeginCoords, sliceSize) as T;
}
Expand Down Expand Up @@ -1037,8 +1026,9 @@ export class MathBackendWebGL extends KernelBackend {
const flattenShape = array_ops_util.getReshapedPermuted(
paddedX.shape, blockShape, prod, false);

return paddedX.reshape(reshapedPaddedShape)
.transpose(permutedReshapedPaddedPermutation)
return transpose(
paddedX.reshape(reshapedPaddedShape),
permutedReshapedPaddedPermutation)
.reshape(flattenShape) as T;
}

Expand Down Expand Up @@ -1127,7 +1117,7 @@ export class MathBackendWebGL extends KernelBackend {
const permutation = axis_util.getAxesPermutation([axis], x.rank);
let permutedX = x;
if (permutation != null) {
permutedX = x.transpose(permutation);
permutedX = transpose(x, permutation);
axis = axis_util.getInnerMostAxes(1, x.rank)[0];
}

Expand All @@ -1141,7 +1131,7 @@ export class MathBackendWebGL extends KernelBackend {
a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments)
.reshape(outShape);
if (permutation != null) {
result = result.transpose(axis_util.getUndoAxesPermutation(permutation));
result = transpose(result, axis_util.getUndoAxesPermutation(permutation));
}
return result;
}
Expand Down
54 changes: 54 additions & 0 deletions tfjs-core/src/backends/webgl/kernels/Transpose.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {TypedArray} from '../../../../src/types';
import {transposeImpl as cpuTranspose} from '../../../backends/cpu/kernels/Transpose_impl';
import {Transpose, TransposeAttrs, TransposeInputs} from '../../../kernel_names';
import {KernelConfig, TensorInfo} from '../../../kernel_registry';
import {MathBackendWebGL} from '../backend_webgl';
import {transposeImpl} from './Transpose_impl';

/**
 * Kernel registration for `Transpose` on the 'webgl' backend.
 *
 * The kernel function first derives the output shape by reordering the input
 * shape with `perm`. If the backend reports that the input should be handled
 * on the CPU, the values held in the input's texture data are transposed with
 * the CPU implementation and attached to a freshly made tensor info;
 * otherwise the WebGL `transposeImpl` runs the shader program.
 */
export const transposeConfig: KernelConfig = {
  kernelName: Transpose,
  backendName: 'webgl',
  kernelFunc: ({inputs, attrs, backend}) => {
    const {x} = inputs as TransposeInputs;
    const {perm} = attrs as {} as TransposeAttrs;
    const webglBackend = backend as MathBackendWebGL;

    // Output axis `axis` has the extent of input axis `perm[axis]`.
    const outShape: number[] = new Array(x.shape.length);
    for (let axis = 0; axis < outShape.length; ++axis) {
      outShape[axis] = x.shape[perm[axis]];
    }

    // GPU path: nothing else to prepare, the shader program does the work.
    if (!webglBackend.shouldExecuteOnCPU([x])) {
      return transposeImpl(x, perm, webglBackend);
    }

    // CPU path: shuffle the values directly and store them on the output's
    // texture data entry.
    const cpuValues = webglBackend.texData.get(x.dataId).values as TypedArray;
    const transposed =
        cpuTranspose(cpuValues, x.shape, x.dtype, perm, outShape);

    const out: TensorInfo = webglBackend.makeTensorInfo(outShape, x.dtype);
    webglBackend.texData.get(out.dataId).values = transposed;
    return out;
  }
};
30 changes: 30 additions & 0 deletions tfjs-core/src/backends/webgl/kernels/Transpose_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {env} from '../../../environment';
import {TensorInfo} from '../../../kernel_registry';
import {MathBackendWebGL} from '../backend_webgl';
import {TransposeProgram} from '../transpose_gpu';
import {TransposePackedProgram} from '../transpose_packed_gpu';

/**
 * Runs the WebGL transpose program for `x` with the permutation `perm` on
 * `backend` and returns the resulting tensor info, which keeps the dtype of
 * `x`. The packed program is used when the 'WEBGL_PACK_ARRAY_OPERATIONS'
 * environment flag is set, and the unpacked one otherwise.
 */
export function transposeImpl(
    x: TensorInfo, perm: number[], backend: MathBackendWebGL): TensorInfo {
  const usePacked = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS');
  if (usePacked) {
    return backend.runWebGLProgram(
        new TransposePackedProgram(x.shape, perm), [x], x.dtype);
  }
  return backend.runWebGLProgram(
      new TransposeProgram(x.shape, perm), [x], x.dtype);
}
2 changes: 2 additions & 0 deletions tfjs-core/src/backends/webgl/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {fromPixelsConfig} from './kernels/FromPixels';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {squareConfig} from './kernels/Square';
import {squaredDifferenceConfig} from './kernels/SquaredDifference';
import {transposeConfig} from './kernels/Transpose';

// List all kernel configs here
const kernelConfigs: KernelConfig[] = [
Expand All @@ -29,6 +30,7 @@ const kernelConfigs: KernelConfig[] = [
nonMaxSuppressionV5Config,
squareConfig,
squaredDifferenceConfig,
transposeConfig,
];

for (const kernelConfig of kernelConfigs) {
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/backends/webgl/webgl_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,9 @@ describeWithFlags('gramSchmidt-non-tiny', WEBGL_ENVS, () => {
// can complete in the timeout limit of the unit test.
const xs: Tensor2D = tf.randomUniform([8, 16]);
const y = tf.linalg.gramSchmidt(xs) as Tensor2D;
const yTransposed: Tensor2D = y.transpose();
expectArraysClose(
await y.matMul(y.transpose()).data(), await tf.eye(8).data());
await y.matMul(yTransposed).data(), await tf.eye(8).data());
});
});

Expand Down
Loading