Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@
import './flags_webgpu';

import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBackend, Rank, RecursiveArray, ShapeMap, slice_util, Tensor, Tensor2D, Tensor3D, Tensor4D, TimingInfo, util} from '@tensorflow/tfjs-core';
import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile';
// TODO(xing.xu): use FusedConv2DConfig from backend_util:
// https://github.com/tensorflow/tfjs/issues/2471
// tslint:disable-next-line: no-imports-from-dist
import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util';
// TODO: Import reduce_util from backend_util with next release of core.
import {computeOptimalWindowSize} from '@tensorflow/tfjs-core/src/ops/reduce_util';
// TODO: import sumOutType directly from '@tensorflow/tfjs-core' with next
// release of core.
import {sumOutType} from '@tensorflow/tfjs-core/src/types';
import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile';

import {BufferManager} from './buffer_manager';
import {ArgMinMaxProgram} from './kernels/argminmax_webgpu';
Expand All @@ -41,6 +46,7 @@ import {MatMulPackedProgram} from './kernels/matmul_packed_webgpu';
import {MatMulProgram} from './kernels/matmul_webgpu';
import {MaxPoolProgram} from './kernels/maxpool_webgpu';
import {PadProgram} from './kernels/pad_webgpu';
import {ReduceProgram} from './kernels/reduce_webgpu';
import {ResizeBilinearProgram} from './kernels/resize_bilinear_webgpu';
import {SelectProgram} from './kernels/select_webgpu';
import {SliceProgram} from './kernels/slice_webgpu';
Expand Down Expand Up @@ -825,6 +831,45 @@ export class WebGPUBackend extends KernelBackend {
return this.argMinMaxReduce(x, axis, 'max');
}

private reduce(x: Tensor2D, reduceType: 'max'|'min'|'sum', dtype: DataType):
Tensor2D {
const batchSize = x.shape[0];
const inSize = x.shape[1];
const windowSize = computeOptimalWindowSize(inSize);
const reduceInfo = {windowSize, inSize, batchSize};
const program = new ReduceProgram(reduceInfo, reduceType);
const output = this.makeOutputArray(program.outputShape, dtype);
return this.compileAndRun(program, [x], output);
}

max(x: Tensor, axes: number[]): Tensor {
backend_util.assertAxesAreInnerMostDims('max', axes, x.rank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
const reduceSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, reduceSize);
return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape);
}

min(x: Tensor, axes: number[]): Tensor {
backend_util.assertAxesAreInnerMostDims('min', axes, x.rank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
const reduceSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, reduceSize);
return this.reduce(a2D, 'min', a2D.dtype).reshape(outShape);
}

sum(x: Tensor, axes: number[]): Tensor {
backend_util.assertAxesAreInnerMostDims('sum', axes, x.rank);
const [outShape, reduceShape] =
backend_util.computeOutAndReduceShapes(x.shape, axes);
const reduceSize = util.sizeFromShape(reduceShape);
const a2D = x.as2D(-1, reduceSize);
const outputDType = sumOutType(x.dtype);
return this.reduce(a2D, 'sum', outputDType).reshape(outShape);
}

clip<T extends Tensor>(x: T, min: number, max: number): T {
const program = new ClipProgram(x.shape, min, max);
return this.compileAndRun(program, [x]);
Expand Down Expand Up @@ -916,6 +961,14 @@ export class WebGPUBackend extends KernelBackend {
return this.compileAndRun(program, [x]);
}

abs<T extends Tensor>(x: T): T {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.abs(x);
}
const program = new UnaryOpProgram(x.shape, unary_op.ABS);
return this.compileAndRun(program, [x]);
}

prelu<T extends Tensor>(x: T, alpha: T): T {
const program = new BinaryOpProgram(binary_op.PRELU, x.shape, alpha.shape);
return this.compileAndRun(program, [x, alpha]);
Expand Down
119 changes: 119 additions & 0 deletions tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util, util} from '@tensorflow/tfjs-core';
// TODO : use backend_util.reduce_util with the next release of tfjs-core.
import {ReduceInfo} from '@tensorflow/tfjs-core/src/ops/reduce_util';
import {getCoordsDataType} from '../shader_preprocessor';
import {computeDispatch} from '../webgpu_util';
import {WebGPUProgram} from './webgpu_program';

export class ReduceProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[], y: number[]};
dispatch: [number, number, number];
workGroupSize: [number, number, number];
variableNames = ['x'];

constructor(reduceInfo: ReduceInfo, reduceType: 'max'|'min'|'sum') {
const inputShape = [reduceInfo.batchSize, reduceInfo.inSize];
const [outputShape, reduceShape] =
backend_util.computeOutAndReduceShapes(inputShape, [1]);
this.outputShape = outputShape.length === 0 ? [1] : outputShape;
const reduceSize = util.sizeFromShape(reduceShape);

const reductionFactor = 2;
const xMaxThreads = 1024;
const xThreads =
Math.min(Math.ceil(reduceSize / reductionFactor), xMaxThreads);

this.workGroupSize = [xThreads, 1, 1];
this.dispatchLayout = {x: [], y: this.outputShape.map((d, i) => i)};
this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
const reduceInSharedMemory = xThreads > 1;

const minmaxOp = `
if (candidate ${reduceType === 'min' ? '<' : '>'} bestValue
&& !isnan(candidate))
{ bestValue = candidate; }
`;
const sumOp = ' bestValue += candidate; ';
const op =
(reduceType === 'min' || reduceType === 'max') ? minmaxOp : sumOp;

const sharedMemorySnippet = `
shared float xBestValues[WorkGroupSize];
`;
const sharedMemoryReduceSnippet = `
xBestValues[gl_LocalInvocationID.x] = bestValue;
${reduceType === 'sum' ? 'bestValue=0;' : ' '}
int currentSize = WorkGroupSize;
while (currentSize > 1) {
barrier();
for (int w = 0; w < ${reductionFactor}; ++w) {
int i = int(gl_LocalInvocationID.x) * ${reductionFactor} + w;
if (i < currentSize) {
float candidate = xBestValues[i];
${op}
}
}
xBestValues[gl_LocalInvocationID.x] = bestValue;
currentSize = DIV_CEIL(currentSize, ${reductionFactor});
${reduceType === 'sum' ? 'if(currentSize > 1) bestValue=0;' : ''}
}
if (gl_LocalInvocationID.x == 0) {
setOutput(flatOutputIndex, bestValue);
}
`;

const outputCoordsType = getCoordsDataType(this.outputShape.length);

this.userCode = `
#define DIV_CEIL(x, y) (((x) - 1) / (y) + 1)
const int WorkGroupSize = int(gl_WorkGroupSize.x);
${reduceInSharedMemory ? sharedMemorySnippet : ''}
int getOffset() {
const ${outputCoordsType} outputCoords = getOutputCoords();
int offset = ${
this.outputShape.length === 1 ? 'outputCoords' :
'outputCoords[0]'} * xShape[1];
return offset;
}
void main() {
const int offset= getOffset();
${
reduceType === 'sum' ? 'float bestValue = 0;' :
'float bestValue = x[offset];'}
const int Length = ${inputShape.length === 1 ? 'xShape' : 'xShape[1]'};
const int WorkPerThread = DIV_CEIL(Length, WorkGroupSize);
for (int w = 0; w < WorkPerThread; ++w) {
int i = int(gl_GlobalInvocationID.x) * WorkPerThread + w;
if (i < Length) {
float candidate = x[offset + i];
${(reduceType === 'max' || reduceType === 'min') ? minmaxOp : sumOp}
}
}
const int flatOutputIndex = int(gl_GlobalInvocationID.y);
${
reduceInSharedMemory ? sharedMemoryReduceSnippet :
'setOutput(flatOutputIndex, bestValue);'}
}
`;
}
}
1 change: 1 addition & 0 deletions tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export const LINEAR = `return x;`;
export const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;

export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;
export const ABS = `return abs(a);`;

export class UnaryOpProgram implements WebGPUProgram {
outputShape: number[];
Expand Down
39 changes: 39 additions & 0 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,45 @@ const TEST_FILTERS: TestFilter[] = [
'grad', // 'depthwiseConv2DDerFilter' not yet implemented, slice not yet
// implemented
]
},
{
include: 'Reduction: max',
excludes: [
'5D', // Rank 5 is not yet implemented.
'6D', // Rank 5 is not yet implemented.
'accepts tensor with bool', // Actual != Expected.
'gradient', // zerosLike not yet implemented.
]
},
{
include: 'Reduction: min',
excludes: [
'5D', // Rank 5 is not yet implemented.
'6D', // Rank 5 is not yet implemented.
'accepts tensor with bool', // Actual != Expected.
'gradient', // zerosLike not yet implemented.
]
},
{
include: 'Reduction: sum',
excludes: [
'dtype bool', // not support dtype bool yet.
'5D', // Rank 5 is not yet implemented.
'6D', // Rank 5 is not yet implemented.
'accepts tensor with bool', // Actual != Expected.
'gradient', // zerosLike not yet implemented.
]
},
{
include: 'abs',
excludes: [
'complex', // No complex support yet.
'5D', // Rank 5 is not yet implemented.
'6D', // Rank 5 is not yet implemented.
'accepts tensor with bool', // Actual != Expected.
'gradient', // zerosLike not yet implemented.
'absoluteDifference', // absoluteDifference not yet implemented
]
}
];

Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/backends/backend_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export * from '../ops/broadcast_util';
export * from '../ops/concat_util';
export * from '../ops/conv_util';
export {Activation, FusedConv2DConfig} from '../ops/fused_util';

export * from '../ops/reduce_util';
export {BackendValues, TypedArray, upcastType, PixelData} from '../types';
export {MemoryInfo, TimingInfo} from '../engine';

Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
export {SGDOptimizer} from './optimizers/sgd_optimizer';
export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor';
export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types';
export {DataType, DataTypeMap, DataValues, Rank, RecursiveArray, ShapeMap, TensorLike} from './types';
export {DataType, DataTypeMap, DataValues, Rank, RecursiveArray, ShapeMap, sumOutType, TensorLike} from './types';
Copy link
Contributor

@dsmilkov dsmilkov Jan 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this introduces a top-level tf.sumOutType method, which we don't want as part of our API since we are aligning with TF Python. All the other exports from './types' are interfaces/types that do not exist at runtime. We should instead expose sumOutType as part of the backend_util namespace.


export * from './ops/ops';
export {LSTMCellFunc} from './ops/lstm';
Expand Down