diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 8f240e9fcd5..588fe9282a9 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -55,7 +55,7 @@ export class BinaryOpProgram implements WebGPUProgram { // used as uniform. this.lastDimensionSize = this.useSharedMemoryWithB ? bShape[0] : aShape[0]; - this.shaderKey = `binary_${this.type}_${op}_${this.lastDimensionSize}`; + this.shaderKey = `binary_${op}_${this.lastDimensionSize}`; this.type = 'shared'; // This is an experimental value when using shared memory. // Note that the maximum of workgroup X dimension is 256. @@ -68,23 +68,19 @@ export class BinaryOpProgram implements WebGPUProgram { if (aDivisibleBy4 && bDivisibleBy4) { this.outputComponent = 4; this.variableComponents = [4, 4]; - this.type = 'vec4'; } else if ( - (op === BinaryOpType.SUB || op === BinaryOpType.ADD || - op === BinaryOpType.MUL || op === BinaryOpType.DIV) && - ((aDivisibleBy4 && - (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || - (bDivisibleBy4 && - (util.isScalarShape(aShape) || aShape[aShape.length - 1] === 1)))) { - this.type = 'custom'; + (aDivisibleBy4 && + (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || + (bDivisibleBy4 && + (util.isScalarShape(aShape) || aShape[aShape.length - 1] === 1))) { this.outputComponent = 4; this.variableComponents = aDivisibleBy4 ? [4, 1] : [1, 4]; } else { this.outputComponent = 1; this.variableComponents = [1, 1]; - this.type = 'plain'; } - this.shaderKey = `binary_${this.type}_${op}_${this.variableComponents}`; + this.type = 'nonshared'; + this.shaderKey = `binary_${op}_${this.variableComponents}`; // TODO(jiajia.qin@intel.com): Heuristically select a good work group // size. this.workgroupSize = [128, 1, 1];