From fa2d5f2dccbc194073a089670c96ed8675744422 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 1 Mar 2023 11:23:13 +0800 Subject: [PATCH 1/5] webgpu: Optimize broadcast mul/div --- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 34 ++++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 3b17dd30850..543ea3bedf1 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -25,7 +25,6 @@ export class BinaryOpProgram implements WebGPUProgram { dispatch: [number, number, number]; dispatchLayout: {x: number[]}; outputComponent: number; - isVec4: boolean; op: BinaryOpType; outputShape: number[]; shaderKey: string; @@ -33,6 +32,7 @@ export class BinaryOpProgram implements WebGPUProgram { variableNames = ['A', 'B']; workgroupSize: [number, number, number]; workPerThread: number; + variableComponents?: number[]; private lastDimensionSize: number; private useSharedMemoryWithA: boolean; @@ -50,7 +50,7 @@ export class BinaryOpProgram implements WebGPUProgram { bShape.length <= 1 && aShape.length > 1 && bShape[0] < 128; if (this.useSharedMemoryWithA || this.useSharedMemoryWithB) { - this.isVec4 = false; + this.outputComponent = 1; // lastDimensionSize is used as sharedBuf array size, so can not be // used as uniform. this.lastDimensionSize = @@ -63,14 +63,21 @@ export class BinaryOpProgram implements WebGPUProgram { this.workgroupSize = [256, 1, 1]; this.workPerThread = 1; } else { - if (util.arraysEqual(aShape, bShape) && - util.sizeFromShape(aShape) % 4 === 0) { - this.isVec4 = true; + const aEqualsB = util.arraysEqual(aShape, bShape); + if ((op === BinaryOpType.MUL || op === BinaryOpType.DIV) && !aEqualsB && + aShape.length > 0 && aShape[aShape.length - 1] % 4 === 0 && + (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) { + this.type = 'custom'; + this.outputComponent = 4; + this.variableComponents = [4, 1]; + this.workPerThread = 4; + + } else if (aEqualsB && util.sizeFromShape(aShape) % 4 === 0) { this.outputComponent = 4; this.type = 'vec4'; this.workPerThread = 4; } else { - this.isVec4 = false; + this.outputComponent = 1; this.type = 'plain'; this.workPerThread = 1; } @@ -86,10 +93,10 @@ export class BinaryOpProgram implements WebGPUProgram { getUserCode(): string { let userCode; - const dType = this.isVec4 ? 'vec4' : 'f32'; + const dType = this.outputComponent === 4 ? 'vec4' : 'f32'; const opFnStr = ` fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} { - ${getBinaryOpString(this.op, this.isVec4)} + ${getBinaryOpString(this.op, this.outputComponent === 4)} }; `; @@ -121,6 +128,17 @@ export class BinaryOpProgram implements WebGPUProgram { } } `; + } else if (this.type === 'custom') { + userCode = ` + ${main('index')} { + if (index < uniforms.size) { + let a = getAByOutputIndex(index); + let b = getBByOutputIndex(index * ${this.outputComponent}); + setOutputAtIndex(index, ${ + this.op === BinaryOpType.MUL ? 'a * b' : 'a / b'}); + } + } + `; } else { userCode = ` ${opFnStr} From 34e92e9995ce330dd615735dd87f16ce5e5305b2 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Thu, 2 Mar 2023 15:58:23 +0800 Subject: [PATCH 2/5] refactor --- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 52 +++++++++++---------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 543ea3bedf1..4f4c0962ec4 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -18,7 +18,7 @@ import {backend_util, util} from '@tensorflow/tfjs-core'; import {BinaryOpType, getBinaryOpString} from './binary_op_util'; -import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {getMainHeaderString as main, typeSnippet, WebGPUProgram} from './webgpu_program'; import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class BinaryOpProgram implements WebGPUProgram { @@ -55,8 +55,7 @@ export class BinaryOpProgram implements WebGPUProgram { // used as uniform. this.lastDimensionSize = this.useSharedMemoryWithB ? bShape[0] : aShape[0]; - this.shaderKey = `binary_${this.type}_${op}_${this.lastDimensionSize}_${ - this.useSharedMemoryWithB}`; + this.shaderKey = `binary_${this.type}_${op}_${this.lastDimensionSize}`; this.type = 'shared'; // This is an experimental value when using shared memory. // Note that the maximum of workgroup X dimension is 256. @@ -64,15 +63,22 @@ export class BinaryOpProgram implements WebGPUProgram { this.workPerThread = 1; } else { const aEqualsB = util.arraysEqual(aShape, bShape); - if ((op === BinaryOpType.MUL || op === BinaryOpType.DIV) && !aEqualsB && - aShape.length > 0 && aShape[aShape.length - 1] % 4 === 0 && - (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) { + const aDivisibleBy4 = + aShape.length > 0 && aShape[aShape.length - 1] % 4 === 0; + const bDivisibleBy4 = + bShape.length > 0 && bShape[bShape.length - 1] % 4 === 0; + if ((op === BinaryOpType.SUB || op === BinaryOpType.ADD || + op === BinaryOpType.MUL || op === BinaryOpType.DIV) && + !aEqualsB && + ((aDivisibleBy4 && + (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || + (bDivisibleBy4 && + (util.isScalarShape(aShape) || aShape[aShape.length - 1] === 1)))) { this.type = 'custom'; this.outputComponent = 4; - this.variableComponents = [4, 1]; + this.variableComponents = aDivisibleBy4 ? [4, 1] : [1, 4]; this.workPerThread = 4; - - } else if (aEqualsB && util.sizeFromShape(aShape) % 4 === 0) { + } else if (aDivisibleBy4 && bDivisibleBy4) { this.outputComponent = 4; this.type = 'vec4'; this.workPerThread = 4; @@ -81,7 +87,7 @@ export class BinaryOpProgram implements WebGPUProgram { this.type = 'plain'; this.workPerThread = 1; } - this.shaderKey = `binary_${this.type}_${op}`; + this.shaderKey = `binary_${this.type}_${op}_${this.variableComponents}`; // TODO(jiajia.qin@intel.com): Heuristically select a good work group // size. this.workgroupSize = [128, 1, 1]; @@ -93,9 +99,15 @@ export class BinaryOpProgram implements WebGPUProgram { getUserCode(): string { let userCode; - const dType = this.outputComponent === 4 ? 'vec4' : 'f32'; const opFnStr = ` - fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} { + fn binaryOperation(a : ${ + typeSnippet( + this.variableComponents ? this.variableComponents[0] : + this.outputComponent)}, b : ${ + typeSnippet( + this.variableComponents ? this.variableComponents[1] : + this.outputComponent)}) -> ${ + typeSnippet(this.outputComponent)} { ${getBinaryOpString(this.op, this.outputComponent === 4)} }; `; @@ -128,24 +140,14 @@ export class BinaryOpProgram implements WebGPUProgram { } } `; - } else if (this.type === 'custom') { - userCode = ` - ${main('index')} { - if (index < uniforms.size) { - let a = getAByOutputIndex(index); - let b = getBByOutputIndex(index * ${this.outputComponent}); - setOutputAtIndex(index, ${ - this.op === BinaryOpType.MUL ? 'a * b' : 'a / b'}); - } - } - `; } else { userCode = ` ${opFnStr} ${main('index')} { if (index < uniforms.size) { - let a = getAByOutputIndex(index); - let b = getBByOutputIndex(index); + let coords = getCoordsFromIndex(index * ${this.outputComponent}); + let a = getAByOutputCoords(coords); + let b = getBByOutputCoords(coords); setOutputAtIndex(index, binaryOperation(a, b)); } } From 6d177a346744e95dad66c7c84f1acd89b46016c3 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Fri, 3 Mar 2023 15:48:59 +0800 Subject: [PATCH 3/5] nits --- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 4f4c0962ec4..c17b7ed5bc3 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -31,8 +31,7 @@ export class BinaryOpProgram implements WebGPUProgram { size = true; variableNames = ['A', 'B']; workgroupSize: [number, number, number]; - workPerThread: number; - variableComponents?: number[]; + variableComponents: number[]; private lastDimensionSize: number; private useSharedMemoryWithA: boolean; @@ -60,16 +59,18 @@ export class BinaryOpProgram implements WebGPUProgram { // This is an experimental value when using shared memory. // Note that the maximum of workgroup X dimension is 256. this.workgroupSize = [256, 1, 1]; - this.workPerThread = 1; } else { - const aEqualsB = util.arraysEqual(aShape, bShape); const aDivisibleBy4 = aShape.length > 0 && aShape[aShape.length - 1] % 4 === 0; const bDivisibleBy4 = bShape.length > 0 && bShape[bShape.length - 1] % 4 === 0; - if ((op === BinaryOpType.SUB || op === BinaryOpType.ADD || + if (aDivisibleBy4 && bDivisibleBy4) { + this.outputComponent = 4; + this.type = 'vec4'; + } else if ( + (op === BinaryOpType.SUB || op === BinaryOpType.ADD || op === BinaryOpType.MUL || op === BinaryOpType.DIV) && - !aEqualsB && + !util.arraysEqual(aShape, bShape) && ((aDivisibleBy4 && (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || (bDivisibleBy4 && @@ -77,15 +78,9 @@ export class BinaryOpProgram implements WebGPUProgram { this.type = 'custom'; this.outputComponent = 4; this.variableComponents = aDivisibleBy4 ? [4, 1] : [1, 4]; - this.workPerThread = 4; - } else if (aDivisibleBy4 && bDivisibleBy4) { - this.outputComponent = 4; - this.type = 'vec4'; - this.workPerThread = 4; } else { this.outputComponent = 1; this.type = 'plain'; - this.workPerThread = 1; } this.shaderKey = `binary_${this.type}_${op}_${this.variableComponents}`; // TODO(jiajia.qin@intel.com): Heuristically select a good work group @@ -94,7 +89,7 @@ export class BinaryOpProgram implements WebGPUProgram { } this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workgroupSize, - [this.workPerThread, 1, 1]); + [this.outputComponent, 1, 1]); } getUserCode(): string { From c5f0a11b80cda9b54e35c36b19f3db3c35f2e0d3 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Tue, 7 Mar 2023 15:58:20 +0800 Subject: [PATCH 4/5] Fix potential issues. --- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index c17b7ed5bc3..772ab4fd2c4 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -18,7 +18,7 @@ import {backend_util, util} from '@tensorflow/tfjs-core'; import {BinaryOpType, getBinaryOpString} from './binary_op_util'; -import {getMainHeaderString as main, typeSnippet, WebGPUProgram} from './webgpu_program'; +import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class BinaryOpProgram implements WebGPUProgram { @@ -70,7 +70,6 @@ export class BinaryOpProgram implements WebGPUProgram { } else if ( (op === BinaryOpType.SUB || op === BinaryOpType.ADD || op === BinaryOpType.MUL || op === BinaryOpType.DIV) && - !util.arraysEqual(aShape, bShape) && ((aDivisibleBy4 && (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || (bDivisibleBy4 && @@ -94,15 +93,9 @@ export class BinaryOpProgram implements WebGPUProgram { getUserCode(): string { let userCode; + const dType = this.outputComponent === 4 ? 'vec4' : 'f32'; const opFnStr = ` - fn binaryOperation(a : ${ - typeSnippet( - this.variableComponents ? this.variableComponents[0] : - this.outputComponent)}, b : ${ - typeSnippet( - this.variableComponents ? this.variableComponents[1] : - this.outputComponent)}) -> ${ - typeSnippet(this.outputComponent)} { + fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} { ${getBinaryOpString(this.op, this.outputComponent === 4)} }; `; @@ -141,8 +134,8 @@ export class BinaryOpProgram implements WebGPUProgram { ${main('index')} { if (index < uniforms.size) { let coords = getCoordsFromIndex(index * ${this.outputComponent}); - let a = getAByOutputCoords(coords); - let b = getBByOutputCoords(coords); + let a = ${dType}(getAByOutputCoords(coords)); + let b = ${dType}(getBByOutputCoords(coords)); setOutputAtIndex(index, binaryOperation(a, b)); } } From 65a6f6d6c4bc6649a8f22337a2fa760726bfa926 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 8 Mar 2023 15:02:48 +0800 Subject: [PATCH 5/5] address Yang's comment --- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 772ab4fd2c4..8f240e9fcd5 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -50,6 +50,7 @@ export class BinaryOpProgram implements WebGPUProgram { if (this.useSharedMemoryWithA || this.useSharedMemoryWithB) { this.outputComponent = 1; + this.variableComponents = [1, 1]; // lastDimensionSize is used as sharedBuf array size, so can not be // used as uniform. this.lastDimensionSize = @@ -66,6 +67,7 @@ export class BinaryOpProgram implements WebGPUProgram { bShape.length > 0 && bShape[bShape.length - 1] % 4 === 0; if (aDivisibleBy4 && bDivisibleBy4) { this.outputComponent = 4; + this.variableComponents = [4, 4]; this.type = 'vec4'; } else if ( (op === BinaryOpType.SUB || op === BinaryOpType.ADD || @@ -79,6 +81,7 @@ export class BinaryOpProgram implements WebGPUProgram { this.variableComponents = aDivisibleBy4 ? [4, 1] : [1, 4]; } else { this.outputComponent = 1; + this.variableComponents = [1, 1]; this.type = 'plain'; } this.shaderKey = `binary_${this.type}_${op}_${this.variableComponents}`;