Skip to content

Commit 14e4a79

Browse files
committed
[webgpu] Enable test for fusedConv2D
FIX #2660
1 parent 695f4aa commit 14e4a79

File tree

4 files changed

+31
-10
lines changed

4 files changed

+31
-10
lines changed

tfjs-backend-webgpu/src/backend_webgpu.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,14 @@ export class WebGPUBackend extends KernelBackend {
805805
convInfo.strideHeight, convInfo.strideWidth
806806
];
807807

808-
return this.compileAndRun(program, [input, filter], output, dimensions);
808+
const inputs: Tensor[] = [input, filter];
809+
if (hasBias) {
810+
inputs.push(bias);
811+
}
812+
if (hasPreluActivationWeights) {
813+
inputs.push(preluActivationWeights);
814+
}
815+
return this.compileAndRun(program, inputs, output, dimensions);
809816
}
810817

811818
private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'):

tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,12 @@ export class Conv2DMMProgram implements WebGPUProgram {
5858
const tileAOuter = this.workGroupSize[1] * elementsPerThread[1];
5959
const tileBOuter = this.workGroupSize[0] * elementsPerThread[0];
6060
const tileInner = tileAOuter > tileBOuter ? tileAOuter : tileBOuter;
61-
util.assert(tileInner % this.workGroupSize[0] === 0 &&
62-
tileInner % this.workGroupSize[1] === 0,
63-
() => 'tileInner must be multiple of workgroupsize.x and workgroupsize.y');
61+
util.assert(
62+
tileInner % this.workGroupSize[0] === 0 &&
63+
tileInner % this.workGroupSize[1] === 0,
64+
() =>
65+
// tslint:disable-next-line: max-line-length
66+
'tileInner must be multiple of workgroupsize.x and workgroupsize.y');
6467
const tileSizeA = [tileAOuter, tileInner];
6568
const tileSizeB = [tileInner, tileBOuter];
6669
const dimAOuter = this.outputShape[1] * this.outputShape[2];
@@ -72,8 +75,8 @@ export class Conv2DMMProgram implements WebGPUProgram {
7275
`x[getFlatIndex(coord, xShape)]` :
7376
`coordsInBounds(coord, xShape) ? x[getFlatIndex(coord, xShape)] : 0`;
7477
const fitB = tilesFitEvenlyIntoShape(tileSizeB, [dimInner, dimBOuter]);
75-
const sampleB =
76-
fitB ? `W[row * dimBOuter + col]` :
78+
const sampleB = fitB ?
79+
`W[row * dimBOuter + col]` :
7780
`coordsInBounds(ivec2(row, col), ivec2(dimInner, dimBOuter)) ?
7881
W[row * dimBOuter + col] : 0`;
7982

@@ -90,7 +93,7 @@ export class Conv2DMMProgram implements WebGPUProgram {
9093
}`;
9194
} else {
9295
activationSnippet = `
93-
float activation(float x) {
96+
float activation(float a) {
9497
${activation}
9598
}
9699
`;
@@ -155,6 +158,7 @@ export class Conv2DMMProgram implements WebGPUProgram {
155158
mm_matMul(dimAOuter, dimInner, dimBOuter);
156159
}
157160
`;
158-
this.shaderKey = `conv2dmm'${elementsPerThread.join('')}${fitA}${fitB}`;
161+
this.shaderKey = `conv2dmm'${elementsPerThread.join('')}${fitA}${fitB}${
162+
addBiasSnippet}${activationSnippet}`;
159163
}
160164
}

tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ import {WebGPUProgram} from './webgpu_program';
2323

2424
export const RELU = 'return max(a, 0.0);';
2525
export const RELU6 = 'return (a < 0.0) ? 0.0 : min(6.0, a);';
26-
export const LINEAR = `return x;`;
27-
export const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
26+
export const LINEAR = `return a;`;
27+
export const ELU = `return (a >= 0.0) ? a : (exp(a) - 1.0);`;
2828

2929
export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;
3030
export const ABS = `return abs(a);`;

tfjs-backend-webgpu/src/setup_test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,16 @@ const TEST_FILTERS: TestFilter[] = [
126126
'fused', // Not yet implemented.
127127
]
128128
},
129+
{
130+
include: 'fused conv2d',
131+
excludes: [
132+
'im2row with prelu', // Actual != expected.
133+
'pointwise with prelu', // Actual != expected.
134+
'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // conv2dDerInput not yet
135+
// implemented
136+
'fused matmul with relu6', // step not yet implemented
137+
]
138+
},
129139
{
130140
include: 'fromPixels',
131141
excludes: [

0 commit comments

Comments
 (0)