Skip to content

Commit 412597b

Browse files
committed
Address Jiajia's comments
1 parent 312aa01 commit 412597b

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

tfjs-backend-webgpu/src/pool2d_webgpu.ts

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ export class Pool2DProgram implements WebGPUProgram {
3333
poolType: 'max'|'avg';
3434
size = true;
3535
computePositions: boolean;
36-
positionStr = '';
36+
flattenPositions: boolean;
37+
includeBatchIndex: boolean;
3738

3839
constructor(
3940
convInfo: backend_util.Conv2DInfo, poolType: 'max'|'avg',
@@ -44,36 +45,36 @@ export class Pool2DProgram implements WebGPUProgram {
4445
}
4546

4647
this.outputShape = convInfo.outShape;
47-
this.computePositions = computePositions;
48-
4948
this.dispatchLayout = flatDispatchLayout(this.outputShape);
5049
this.dispatch = computeDispatch(
5150
this.dispatchLayout, this.outputShape, this.workgroupSize);
5251

53-
if (computePositions) {
54-
this.positionStr = flattenPositions ?
55-
(includeBatchIndex ?
56-
`((batch * uniforms.xShape[1] + xR) * uniforms.xShape[2] + xC) * uniforms.xShape[3] + d` :
57-
`(xR * uniforms.xShape[2] + xC) * uniforms.xShape[3] + d`) :
58-
`wR * uniforms.filterDims.y + wC`;
59-
}
60-
6152
this.poolType = poolType;
53+
this.computePositions = computePositions;
54+
this.flattenPositions = flattenPositions;
55+
this.includeBatchIndex = includeBatchIndex;
6256
this.shaderKey = `pool2D_${poolType}_${computePositions}_${
6357
flattenPositions}_${includeBatchIndex}`;
6458
}
6559

6660
getUserCode(): string {
67-
let updateSnippet = this.computePositions ?
68-
`let currMaxValue = mix(value, maxValue, maxValueFound);
69-
if (value >= currMaxValue) {
70-
maxValue = value;
71-
maxValueFound = 1.0;
72-
maxPosition = ${this.positionStr};
73-
}` :
74-
`resultValue = max(value, resultValue);`;
61+
let updateSnippet: string;
7562
if (this.poolType === 'avg') {
7663
updateSnippet = `resultValue = resultValue + value; count = count + 1.0;`;
64+
} else if (this.computePositions) {
65+
const positionStr = this.flattenPositions ?
66+
(this.includeBatchIndex ?
67+
`((batch * uniforms.xShape[1] + xR) * uniforms.xShape[2] + xC) * uniforms.xShape[3] + d` :
68+
`(xR * uniforms.xShape[2] + xC) * uniforms.xShape[3] + d`) :
69+
`wR * uniforms.filterDims.y + wC`;
70+
updateSnippet = `let currMaxValue = mix(value, maxValue, maxValueFound);
71+
if (value >= currMaxValue) {
72+
maxValue = value;
73+
maxValueFound = 1.0;
74+
maxPosition = ${positionStr};
75+
}`;
76+
} else {
77+
updateSnippet = `resultValue = max(value, resultValue);`;
7778
}
7879

7980
let returnValue = `resultValue`;
@@ -92,15 +93,14 @@ export class Pool2DProgram implements WebGPUProgram {
9293
let xCCorner = xRCCorner.y;
9394
9495
${
95-
this.computePositions ? `var maxValue = 0.0;
96+
this.computePositions ?
97+
`var maxValue = 0.0;
9698
var maxValueFound = 0.0;
9799
var maxPosition = 0;` :
98-
''}
100+
`var resultValue = ${
101+
this.poolType === 'avg' ? '0.0' : '-1.0 / pow(10.0, -20.0)'};`}
99102
100-
var resultValue = ${
101-
this.poolType === 'avg' ? '0.0' : '-1.0 / pow(10.0, -20.0)'};
102103
var count = 0.0;
103-
104104
for (var wR = 0; wR < uniforms.filterDims.x; wR = wR + uniforms.dilation.x) {
105105
let xR = xRCorner + wR;
106106

0 commit comments

Comments
 (0)