[WebGPU EP] fixes bugs in NCHW version of instance norm operator #25092
You can commit the suggested changes from lintrunner.
Pull Request Overview
This PR addresses several bugs in the NCHW implementation of the WebGPU instance normalization operator by propagating the correct logical shapes, fixing dispatch sizing, ensuring tests run in NCHW layout, and cleaning up typos and implicit typing.
- Add `is_nhwc` flag to `DefaultWebGpuExecutionProvider` and use it to enable NCHW support in shader generation.
- Extend CPU tests with two new WebGPU NCHW tests for instance normalization.
- Rename misspelled variables, correct the output size calculation, and adjust the channel scale/shift tensor shape in `instance_norm.cc`.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
File | Description |
---|---|
test/util/include/default_providers.h | Added is_nhwc default parameter to WebGPU provider factory. |
test/util/default_providers.cc | Implemented is_nhwc logic for preferred layout configuration. |
test/providers/cpu/nn/instance_norm_op_test.cc | Added two WebGPU-backed NCHW instance norm tests under USE_WEBGPU. |
core/providers/webgpu/nn/instance_norm.cc | Fixed typo (hight → height), removed incorrect component division in output_size, and updated channel scale/shift tensor shape. |
…rosoft#25092)

The instance norm operator suffered from the following issues that this PR addresses:

1. If {2, 80, 2} is the tensor shape, the tensor holds 320 numbers. {2, 80, 1} is the corresponding logical shape, where each element is a vec2, so it also holds 320 numbers. The InstanceNorm<false> code path was not passing the logical shape into the shader generation function, causing incorrect output.
2. The output_size was being divided by components, which affects how many workers are dispatched. In the case of components=4, 75% of outputs for the InstanceNorm<false> code path were not updated and remained 0, causing correctness issues.
3. All the tests, including ones explicitly marked NCHW, were being run on the preferred data layout (NHWC).
4. Typos and some implicit typing were fixed as well.

P.S. Fixes pyannote model

(cherry picked from commit 3eeff82)