From d108d4c7d121520c88e022827f415133e350f860 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 4 Jan 2023 14:41:34 +0800 Subject: [PATCH] webgpu: Remove isVectorA check The original changes may result in multiple threads writing to the same address. --- tfjs-backend-webgpu/src/matmul_packed_webgpu.ts | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts index 6bd029cd50b..69a8a25b240 100644 --- a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts @@ -144,7 +144,7 @@ const calculateResultSnippet = export function makeMatMulPackedVec4Source( workPerThread: number[], workgroupSize: [number, number, number], transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, - isVectorA = false, broadcastBatch = false): string { + broadcastBatch = false): string { const tileAOuter = workgroupSize[1] * workPerThread[1]; const tileBOuter = workgroupSize[0] * workPerThread[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -172,10 +172,10 @@ export function makeMatMulPackedVec4Source( ${main()} { let localRow = i32(localId.y); - let tileRow = ${isVectorA ? '0' : `localRow * ${rowPerThread}`}; + let tileRow = localRow * ${rowPerThread}; let tileCol = i32(localId.x); - let globalRow = ${isVectorA ? '0' : `i32(globalId.y) * ${rowPerThread}`}; + let globalRow = i32(globalId.y) * ${rowPerThread}; let globalCol = i32(globalId.x); let batch = ${splitK ? 
'0' : 'i32(globalId.z)'}; let batchA = ${ @@ -314,8 +314,7 @@ export function makeMatMulPackedSource( var BCached : array; for (var k = 0; k < ${tileInner}; k++) { for (var inner = 0; inner < ${colPerThread}; inner++) { - BCached[inner] = mm_Bsub[k][localCol + inner * ${ - workgroupSize[0]}]; + BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}]; } for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { let ACached = ${ @@ -333,8 +332,7 @@ export function makeMatMulPackedSource( for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { let gRow = globalRowStart + localRow + innerRow * ${workgroupSize[1]}; for (var innerCol = 0; innerCol < ${colPerThread}; innerCol++) { - let gCol = globalColStart + localCol + innerCol * ${ - workgroupSize[0]}; + let gCol = globalColStart + localCol + innerCol * ${workgroupSize[0]}; mm_write(batch, gRow, gCol, acc[innerRow][innerCol]); } } @@ -599,7 +597,7 @@ export class MatMulPackedProgram implements WebGPUProgram { this.isVec4 ? makeMatMulPackedVec4Source( this.elementsPerThread, this.workgroupSize, this.transposeA, - this.tileInner, false, null, this.isVectorA, true) : + this.tileInner, false, null, true) : (this.isVectorA ? makeVectorMatrixProductSource( this.workgroupSize, this.transposeA) : makeMatMulPackedSource(