[WebGL] Implement packed ScatterND #7292
if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] ||
    flattenedIndex[0] == coords[0] + 1 || flattenedIndex[1] == coords[0] + 1) {
  vec4 updVals = ${updatesSnippet};
  if (flattenedIndex[0] == coords[0]) {
    sum.xy += updVals.xy;
    found.xy = vec2(1.);
  }
  if (flattenedIndex[1] == coords[0]) {
    sum.xy += updVals.zw;
    found.xy = vec2(1.);
  }
  if (flattenedIndex[0] == coords[0] + 1) {
    sum.zw += updVals.xy;
    found.zw = vec2(1.);
  }
  if (flattenedIndex[1] == coords[0] + 1) {
    sum.zw += updVals.zw;
    found.zw = vec2(1.);
  }
}
The four if-branches seemed likely to hurt performance, so I tried replacing them with the vectorized code below, which passed the tests:
vec4 isMatched = 1. - vec4(bvec4(flattenedIndex[0] - coords[0],
flattenedIndex[1] - coords[0],
flattenedIndex[0] - coords[0] - 1,
flattenedIndex[1] - coords[0] - 1));
if (dot(isMatched, vec4(1.)) > 0.) {
vec4 updVals = ${updatesSnippet};
found += isMatched.xxzz + isMatched.yyww;
sum += updVals.xyxy * isMatched.xxzz + updVals.zwzw * isMatched.yyww;
}
...
setOutput(mix(${defaultValueSnippet}, sum, vec4(bvec4(found))));
However, it does not show an obvious improvement (the performance is recorded in the 'Packed-ScatterND-vectorizedBranches' column in the PR description), so we could keep the four if-branches here, which are more readable.
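To make the comparison concrete, here is a minimal CPU-side sketch in TypeScript that emulates one loop iteration of both forms and checks that they agree. The names `branching`, `vectorized`, and `mask` are hypothetical helpers for illustration only; they are not part of the PR or of TensorFlow.js, and the sketch says nothing about GPU performance.

```typescript
// CPU emulation of one loop iteration of the packed ScatterND shader.
// A Vec4 is laid out as [x, y, z, w]. All names here are illustrative.
type Vec4 = [number, number, number, number];
interface State { sum: Vec4; found: Vec4; }

// Branching form: mirrors the four if statements in the shader.
function branching(idx: [number, number], coord: number, upd: Vec4): State {
  const sum: Vec4 = [0, 0, 0, 0];
  const found: Vec4 = [0, 0, 0, 0];
  if (idx[0] === coord) { sum[0] += upd[0]; sum[1] += upd[1]; found[0] = 1; found[1] = 1; }
  if (idx[1] === coord) { sum[0] += upd[2]; sum[1] += upd[3]; found[0] = 1; found[1] = 1; }
  if (idx[0] === coord + 1) { sum[2] += upd[0]; sum[3] += upd[1]; found[2] = 1; found[3] = 1; }
  if (idx[1] === coord + 1) { sum[2] += upd[2]; sum[3] += upd[3]; found[2] = 1; found[3] = 1; }
  return { sum, found };
}

// Vectorized form: builds a 0/1 match mask, then uses multiply-adds.
function vectorized(idx: [number, number], coord: number, upd: Vec4): State {
  // Equivalent of 1. - vec4(bvec4(differences)): 1 where the difference is zero.
  const m: Vec4 = [
    idx[0] === coord ? 1 : 0,
    idx[1] === coord ? 1 : 0,
    idx[0] === coord + 1 ? 1 : 0,
    idx[1] === coord + 1 ? 1 : 0,
  ];
  const sum: Vec4 = [0, 0, 0, 0];
  const found: Vec4 = [0, 0, 0, 0];
  if (m[0] + m[1] + m[2] + m[3] > 0) {
    // found += isMatched.xxzz + isMatched.yyww
    found[0] += m[0] + m[1]; found[1] += m[0] + m[1];
    found[2] += m[2] + m[3]; found[3] += m[2] + m[3];
    // sum += updVals.xyxy * isMatched.xxzz + updVals.zwzw * isMatched.yyww
    sum[0] += upd[0] * m[0] + upd[2] * m[1];
    sum[1] += upd[1] * m[0] + upd[3] * m[1];
    sum[2] += upd[0] * m[2] + upd[2] * m[3];
    sum[3] += upd[1] * m[2] + upd[3] * m[3];
  }
  return { sum, found };
}

// Equivalent of vec4(bvec4(found)): any non-zero count becomes 1.
function mask(found: Vec4): Vec4 {
  return found.map((v) => (v !== 0 ? 1 : 0)) as Vec4;
}
```

In the vectorized form `found` can reach 2 when both indices match the same coordinate, which is why the final `mix` converts it through `bvec4` before use.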
The four if statements are certainly more readable. I'm a bit surprised they don't hurt performance.
Does the lack of a performance difference hold on mobile as well?
Yes, the chart in the PR description includes results from our mobile devices.
LGTM. Nice improvement!
pyu10055
left a comment
Thank you! Is this feature enabled by default and fully tested with the current tests?
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @Linchenn)
tfjs-backend-webgl/src/scatter_packed_gpu.ts line 3 at r1 (raw file):
/** * @license * Copyright 2018 Google LLC. All Rights Reserved.
2023
pyu10055
left a comment
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @Linchenn and @mattsoulanille)
tfjs-backend-webgl/src/scatter_packed_gpu.ts line 76 at r7 (raw file):
} } if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] ||
Is this check necessary, given that the branches it contains will check again?
Linchenn
left a comment
I have tested it locally, for both the current changes and the vectorized-branch optimization, with:
yarn test --test_verbose_timeout_warnings --verbose_failures --nocache_test_results --//:grep='scatterND'
I also checked the correctness of the 'USE-30' model with the local benchmark tool.
I am running a nightly test now: https://pantheon.corp.google.com/cloud-build/builds/f79d985c-2a59-4d2c-98c6-cc904231334a?project=learnjs-174218. I will not merge until it has passed.
To be safer, could we first add a 'WEBGL_PACK_SCATTERND' flag (defaulting to false)?
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @mattsoulanille and @pyu10055)
tfjs-backend-webgl/src/scatter_packed_gpu.ts line 3 at r1 (raw file):
Previously, pyu10055 (Ping Yu) wrote…
2023
I have updated it. It's odd that Reviewable sometimes does not show the latest changes.
tfjs-backend-webgl/src/scatter_packed_gpu.ts line 76 at r7 (raw file):
Previously, pyu10055 (Ping Yu) wrote…
is this check necessary? given the contained the branches will check again?
Yes. The immediately following line, vec4 updVals = ${updatesSnippet};, is a read instruction, and it is executed only if this check passes. This gives a visible, though small (~2 ms for the model), improvement.
Code quote:
vec4 updVals = ${updatesSnippet};
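As a rough illustration of why the outer check can save work, the sketch below counts how many loop iterations would perform the updates read with and without the guard. It is a hypothetical CPU-side model (`countReads` is not a function from the PR), and it counts logical reads only; actual GPU behavior depends on how the driver and hardware handle branching.

```typescript
// Hypothetical model: each entry of `indices` stands for the two flattened
// indices handled in one loop iteration of the packed shader.
function countReads(
    indices: Array<[number, number]>, coord: number, guarded: boolean): number {
  let reads = 0;
  for (const [i0, i1] of indices) {
    // Same condition as the outer if statement in the shader.
    const hit =
        i0 === coord || i1 === coord || i0 === coord + 1 || i1 === coord + 1;
    // Without the guard, the updates value is read on every iteration.
    if (!guarded || hit) {
      reads++;
    }
  }
  return reads;
}

const sampleIndices: Array<[number, number]> = [[0, 1], [2, 3], [4, 5], [6, 7]];
const guardedReads = countReads(sampleIndices, 2, true);
const unguardedReads = countReads(sampleIndices, 2, false);
console.log(guardedReads, unguardedReads);
```

With these sample indices only one iteration touches the output texel pair at coordinates 2 and 3, so the guarded version performs one read instead of four.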
pyu10055
left a comment
Reviewed 2 of 3 files at r4, 1 of 1 files at r5, 1 of 1 files at r8, all commit messages.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @mattsoulanille)
pyu10055
left a comment
LGTM, given it is guarded under WEBGL_PACK flag.
Reviewable status:
complete! 2 of 1 approvals obtained (waiting on @mattsoulanille)
pyu10055
left a comment
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @Linchenn and @mattsoulanille)
tfjs-backend-webgl/src/scatter_packed_gpu.ts line 76 at r7 (raw file):
Previously, Linchenn wrote…
Yes. The immediately following line, vec4 updVals = ${updatesSnippet};, is a read instruction, and it is executed only if this check passes. This gives a visible, though small (~2 ms for the model), improvement.
Can you move this into each child branch to avoid this top-level branch statement?
Linchenn
left a comment
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @mattsoulanille and @pyu10055)
tfjs-backend-webgl/src/scatter_packed_gpu.ts line 76 at r7 (raw file):
Previously, pyu10055 (Ping Yu) wrote…
Can you move this into each child branch to avoid this top-level branch statement?
Merged the mutually exclusive branches into if...else. Thank you, Ping, for thinking deeply about it!
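The merged shader code is not quoted in this conversation, so the following is only one possible reading, emulated on the CPU in TypeScript. It relies on the observation that a single index cannot equal both coords[0] and coords[0] + 1, so each index needs at most one branch. The function name `mergedBranches` is hypothetical, and the placement of the updates read is not modeled here.

```typescript
// A Vec4 is laid out as [x, y, z, w]. Illustrative emulation only.
type Vec4 = [number, number, number, number];

function mergedBranches(
    idx: [number, number], coord: number, upd: Vec4): {sum: Vec4, found: Vec4} {
  const sum: Vec4 = [0, 0, 0, 0];
  const found: Vec4 = [0, 0, 0, 0];

  // idx[0] matches coord or coord + 1, never both, so use if...else.
  if (idx[0] === coord) {
    sum[0] += upd[0]; sum[1] += upd[1];  // sum.xy += updVals.xy
    found[0] = 1; found[1] = 1;
  } else if (idx[0] === coord + 1) {
    sum[2] += upd[0]; sum[3] += upd[1];  // sum.zw += updVals.xy
    found[2] = 1; found[3] = 1;
  }

  // The same reasoning applies to idx[1].
  if (idx[1] === coord) {
    sum[0] += upd[2]; sum[1] += upd[3];  // sum.xy += updVals.zw
    found[0] = 1; found[1] = 1;
  } else if (idx[1] === coord + 1) {
    sum[2] += upd[2]; sum[3] += upd[3];  // sum.zw += updVals.zw
    found[2] = 1; found[3] = 1;
  }
  return {sum, found};
}
```

The two index checks stay separate because both indices can match in the same iteration, for example when idx[0] hits coords[0] and idx[1] hits coords[0] + 1.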
Just completed the nightly tests.
With this PR, the 'USE-batch30' model gains a 30%~50% performance improvement. The 'Packed-ScatterND' column shows this PR's performance on 'USE-batch30'.
Specifically, on a MacBook Pro, the time that 'USE-batch30' spends in 'ScatterNd' drops from 62.84 ms to 15.01 ms.
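For reference, the ScatterNd timings quoted above can be turned into a speedup and a relative reduction with a few lines of arithmetic. The variable names are illustrative, and the figures apply to the ScatterNd kernel time only, not to the whole model.

```typescript
// Timings quoted in this PR for 'ScatterNd' within 'USE-batch30' on a MacBook Pro.
const beforeMs = 62.84;
const afterMs = 15.01;

// Speedup factor of the kernel time: before / after.
const speedup = beforeMs / afterMs;

// Fraction of the original kernel time that was saved.
const reduction = (beforeMs - afterMs) / beforeMs;

console.log(`speedup: ${speedup.toFixed(2)}x`);
console.log(`reduction: ${(reduction * 100).toFixed(1)}%`);
```

This works out to roughly a 4.19x speedup, or about a 76.1% reduction, in the time spent in ScatterNd itself.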
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.