@@ -58,9 +58,12 @@ export class Conv2DMMProgram implements WebGPUProgram {
5858 const tileAOuter = this . workGroupSize [ 1 ] * elementsPerThread [ 1 ] ;
5959 const tileBOuter = this . workGroupSize [ 0 ] * elementsPerThread [ 0 ] ;
6060 const tileInner = tileAOuter > tileBOuter ? tileAOuter : tileBOuter ;
61- util . assert ( tileInner % this . workGroupSize [ 0 ] === 0 &&
62- tileInner % this . workGroupSize [ 1 ] === 0 ,
63- ( ) => 'tileInner must be multiple of workgroupsize.x and workgroupsize.y' ) ;
61+ util . assert (
62+ tileInner % this . workGroupSize [ 0 ] === 0 &&
63+ tileInner % this . workGroupSize [ 1 ] === 0 ,
64+ ( ) =>
65+ // tslint:disable-next-line: max-line-length
66+ 'tileInner must be multiple of workgroupsize.x and workgroupsize.y' ) ;
6467 const tileSizeA = [ tileAOuter , tileInner ] ;
6568 const tileSizeB = [ tileInner , tileBOuter ] ;
6669 const dimAOuter = this . outputShape [ 1 ] * this . outputShape [ 2 ] ;
@@ -72,8 +75,8 @@ export class Conv2DMMProgram implements WebGPUProgram {
7275 `x[getFlatIndex(coord, xShape)]` :
7376 `coordsInBounds(coord, xShape) ? x[getFlatIndex(coord, xShape)] : 0` ;
7477 const fitB = tilesFitEvenlyIntoShape ( tileSizeB , [ dimInner , dimBOuter ] ) ;
75- const sampleB =
76- fitB ? `W[row * dimBOuter + col]` :
78+ const sampleB = fitB ?
79+ `W[row * dimBOuter + col]` :
7780 `coordsInBounds(ivec2(row, col), ivec2(dimInner, dimBOuter)) ?
7881 W[row * dimBOuter + col] : 0` ;
7982
@@ -90,7 +93,7 @@ export class Conv2DMMProgram implements WebGPUProgram {
9093 }` ;
9194 } else {
9295 activationSnippet = `
93- float activation(float x ) {
96+ float activation(float a ) {
9497 ${ activation }
9598 }
9699 ` ;
@@ -155,6 +158,7 @@ export class Conv2DMMProgram implements WebGPUProgram {
155158 mm_matMul(dimAOuter, dimInner, dimBOuter);
156159 }
157160 ` ;
158- this . shaderKey = `conv2dmm'${ elementsPerThread . join ( '' ) } ${ fitA } ${ fitB } ` ;
161+ this . shaderKey = `conv2dmm'${ elementsPerThread . join ( '' ) } ${ fitA } ${ fitB } ${
162+ addBiasSnippet } ${ activationSnippet } `;
159163 }
160164}
0 commit comments