From ed8ab3a7236a6ba287b1c186be788ef9828d04c8 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 13 Mar 2023 10:56:28 +0800 Subject: [PATCH 1/2] webgpu: Remove the unneccessary conditions --- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 8f240e9fcd5..ba96aefb5fe 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -70,13 +70,11 @@ export class BinaryOpProgram implements WebGPUProgram { this.variableComponents = [4, 4]; this.type = 'vec4'; } else if ( - (op === BinaryOpType.SUB || op === BinaryOpType.ADD || - op === BinaryOpType.MUL || op === BinaryOpType.DIV) && - ((aDivisibleBy4 && - (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || - (bDivisibleBy4 && - (util.isScalarShape(aShape) || aShape[aShape.length - 1] === 1)))) { - this.type = 'custom'; + (aDivisibleBy4 && + (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || + (bDivisibleBy4 && + (util.isScalarShape(aShape) || aShape[aShape.length - 1] === 1))) { + this.type = 'mix'; this.outputComponent = 4; this.variableComponents = aDivisibleBy4 ? [4, 1] : [1, 4]; } else { From 2ac96dc4b29237ca24a49d824871af1118c66977 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 13 Mar 2023 13:55:35 +0800 Subject: [PATCH 2/2] nits --- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index ba96aefb5fe..588fe9282a9 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -55,7 +55,7 @@ export class BinaryOpProgram implements WebGPUProgram { // used as uniform. this.lastDimensionSize = this.useSharedMemoryWithB ? bShape[0] : aShape[0]; - this.shaderKey = `binary_${this.type}_${op}_${this.lastDimensionSize}`; + this.shaderKey = `binary_${op}_${this.lastDimensionSize}`; this.type = 'shared'; // This is an experimental value when using shared memory. // Note that the maximum of workgroup X dimension is 256. @@ -68,21 +68,19 @@ export class BinaryOpProgram implements WebGPUProgram { if (aDivisibleBy4 && bDivisibleBy4) { this.outputComponent = 4; this.variableComponents = [4, 4]; - this.type = 'vec4'; } else if ( (aDivisibleBy4 && (util.isScalarShape(bShape) || bShape[bShape.length - 1] === 1)) || (bDivisibleBy4 && (util.isScalarShape(aShape) || aShape[aShape.length - 1] === 1))) { - this.type = 'mix'; this.outputComponent = 4; this.variableComponents = aDivisibleBy4 ? [4, 1] : [1, 4]; } else { this.outputComponent = 1; this.variableComponents = [1, 1]; - this.type = 'plain'; } - this.shaderKey = `binary_${this.type}_${op}_${this.variableComponents}`; + this.type = 'nonshared'; + this.shaderKey = `binary_${op}_${this.variableComponents}`; // TODO(jiajia.qin@intel.com): Heuristically select a good work group // size. this.workgroupSize = [128, 1, 1];