diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 3b17dd30850..8f240e9fcd5 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -25,14 +25,13 @@ export class BinaryOpProgram implements WebGPUProgram { dispatch: [number, number, number]; dispatchLayout: {x: number[]}; outputComponent: number; - isVec4: boolean; op: BinaryOpType; outputShape: number[]; shaderKey: string; size = true; variableNames = ['A', 'B']; workgroupSize: [number, number, number]; - workPerThread: number; + variableComponents: number[]; private lastDimensionSize: number; private useSharedMemoryWithA: boolean; @@ -50,46 +49,57 @@ export class BinaryOpProgram implements WebGPUProgram { bShape.length <= 1 && aShape.length > 1 && bShape[0] < 128; if (this.useSharedMemoryWithA || this.useSharedMemoryWithB) { - this.isVec4 = false; + this.outputComponent = 1; + this.variableComponents = [1, 1]; // lastDimensionSize is used as sharedBuf array size, so can not be // used as uniform. this.lastDimensionSize = this.useSharedMemoryWithB ? bShape[0] : aShape[0]; - this.shaderKey = `binary_${this.type}_${op}_${this.lastDimensionSize}_${ - this.useSharedMemoryWithB}`; + this.shaderKey = `binary_${this.type}_${op}_${this.lastDimensionSize}`; this.type = 'shared'; // This is an experimental value when using shared memory. // Note that the maximum of workgroup X dimension is 256. this.workgroupSize = [256, 1, 1]; - this.workPerThread = 1; } else { - if (util.arraysEqual(aShape, bShape) && - util.sizeFromShape(aShape) % 4 === 0) { - this.isVec4 = true; + const aDivisibleBy4 = + aShape.length > 0 && aShape[aShape.length - 1] % 4 === 0; + const bDivisibleBy4 = + bShape.length > 0 && bShape[bShape.length - 1] % 4 === 0; + if (aDivisibleBy4 && bDivisibleBy4) { this.outputComponent = 4; + this.variableComponents = [4, 4]; this.type = 'vec4'; - this.workPerThread = 4; + } else if ( + (op === BinaryOpType.SUB || op === BinaryOpType.ADD || + op === BinaryOpType.MUL || op === BinaryOpType.DIV) && + ((aDivisibleBy4 && + (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || + (bDivisibleBy4 && + (util.isScalarShape(aShape) || aShape[aShape.length - 1] === 1)))) { + this.type = 'custom'; + this.outputComponent = 4; + this.variableComponents = aDivisibleBy4 ? [4, 1] : [1, 4]; } else { - this.isVec4 = false; + this.outputComponent = 1; + this.variableComponents = [1, 1]; this.type = 'plain'; - this.workPerThread = 1; } - this.shaderKey = `binary_${this.type}_${op}`; + this.shaderKey = `binary_${this.type}_${op}_${this.variableComponents}`; // TODO(jiajia.qin@intel.com): Heuristically select a good work group // size. this.workgroupSize = [128, 1, 1]; } this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workgroupSize, - [this.workPerThread, 1, 1]); + [this.outputComponent, 1, 1]); } getUserCode(): string { let userCode; - const dType = this.isVec4 ? 'vec4' : 'f32'; + const dType = this.outputComponent === 4 ? 'vec4' : 'f32'; const opFnStr = ` fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} { - ${getBinaryOpString(this.op, this.isVec4)} + ${getBinaryOpString(this.op, this.outputComponent === 4)} }; `; @@ -126,8 +136,9 @@ export class BinaryOpProgram implements WebGPUProgram { ${opFnStr} ${main('index')} { if (index < uniforms.size) { - let a = getAByOutputIndex(index); - let b = getBByOutputIndex(index); + let coords = getCoordsFromIndex(index * ${this.outputComponent}); + let a = ${dType}(getAByOutputCoords(coords)); + let b = ${dType}(getBByOutputCoords(coords)); setOutputAtIndex(index, binaryOperation(a, b)); } }