@@ -33,7 +33,8 @@ export class Pool2DProgram implements WebGPUProgram {
3333 poolType : 'max' | 'avg' ;
3434 size = true ;
3535 computePositions : boolean ;
36- positionStr = '' ;
36+ flattenPositions : boolean ;
37+ includeBatchIndex : boolean ;
3738
3839 constructor (
3940 convInfo : backend_util . Conv2DInfo , poolType : 'max' | 'avg' ,
@@ -44,36 +45,36 @@ export class Pool2DProgram implements WebGPUProgram {
4445 }
4546
4647 this . outputShape = convInfo . outShape ;
47- this . computePositions = computePositions ;
48-
4948 this . dispatchLayout = flatDispatchLayout ( this . outputShape ) ;
5049 this . dispatch = computeDispatch (
5150 this . dispatchLayout , this . outputShape , this . workgroupSize ) ;
5251
53- if ( computePositions ) {
54- this . positionStr = flattenPositions ?
55- ( includeBatchIndex ?
56- `((batch * uniforms.xShape[1] + xR) * uniforms.xShape[2] + xC) * uniforms.xShape[3] + d` :
57- `(xR * uniforms.xShape[2] + xC) * uniforms.xShape[3] + d` ) :
58- `wR * uniforms.filterDims.y + wC` ;
59- }
60-
6152 this . poolType = poolType ;
53+ this . computePositions = computePositions ;
54+ this . flattenPositions = flattenPositions ;
55+ this . includeBatchIndex = includeBatchIndex ;
6256 this . shaderKey = `pool2D_${ poolType } _${ computePositions } _${
6357 flattenPositions } _${ includeBatchIndex } `;
6458 }
6559
6660 getUserCode ( ) : string {
67- let updateSnippet = this . computePositions ?
68- `let currMaxValue = mix(value, maxValue, maxValueFound);
69- if (value >= currMaxValue) {
70- maxValue = value;
71- maxValueFound = 1.0;
72- maxPosition = ${ this . positionStr } ;
73- }` :
74- `resultValue = max(value, resultValue);` ;
61+ let updateSnippet : string ;
7562 if ( this . poolType === 'avg' ) {
7663 updateSnippet = `resultValue = resultValue + value; count = count + 1.0;` ;
64+ } else if ( this . computePositions ) {
65+ const positionStr = this . flattenPositions ?
66+ ( this . includeBatchIndex ?
67+ `((batch * uniforms.xShape[1] + xR) * uniforms.xShape[2] + xC) * uniforms.xShape[3] + d` :
68+ `(xR * uniforms.xShape[2] + xC) * uniforms.xShape[3] + d` ) :
69+ `wR * uniforms.filterDims.y + wC` ;
70+ updateSnippet = `let currMaxValue = mix(value, maxValue, maxValueFound);
71+ if (value >= currMaxValue) {
72+ maxValue = value;
73+ maxValueFound = 1.0;
74+ maxPosition = ${ positionStr } ;
75+ }` ;
76+ } else {
77+ updateSnippet = `resultValue = max(value, resultValue);` ;
7778 }
7879
7980 let returnValue = `resultValue` ;
@@ -92,15 +93,14 @@ export class Pool2DProgram implements WebGPUProgram {
9293 let xCCorner = xRCCorner.y;
9394
9495 ${
95- this . computePositions ? `var maxValue = 0.0;
96+ this . computePositions ?
97+ `var maxValue = 0.0;
9698 var maxValueFound = 0.0;
9799 var maxPosition = 0;` :
98- '' }
100+ `var resultValue = ${
101+ this . poolType === 'avg' ? '0.0' : '-1.0 / pow(10.0, -20.0)' } ;`}
99102
100- var resultValue = ${
101- this . poolType === 'avg' ? '0.0' : '-1.0 / pow(10.0, -20.0)' } ;
102103 var count = 0.0;
103-
104104 for (var wR = 0; wR < uniforms.filterDims.x; wR = wR + uniforms.dilation.x) {
105105 let xR = xRCorner + wR;
106106
0 commit comments