Skip to content

[WebGPU EP] fixes bugs in NCHW version of instance norm operator #25092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions onnxruntime/core/providers/webgpu/nn/instance_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader)

shader.MainFunctionBody() << " let batch = workgroup_idx / uniforms.x_shape[1];\n"
<< " let channel = workgroup_idx % uniforms.x_shape[1];\n"
<< " let hight = uniforms.x_shape[2];\n"
<< " // initialize workgroup memory<< \n"
<< " let height = uniforms.x_shape[2];\n"
<< " // initialize workgroup memory\n"
<< " var sum = f32_val_t(0);\n"
<< " var squared_sum = f32_val_t(0);\n"
<< " for (var h = local_idx; h < hight; h += workgroup_size) {\n"
<< " for (var h = local_idx; h < height; h += workgroup_size) {\n"
<< " let indices = x_indices_t(batch, channel, h);\n"
<< " let value = f32_val_t(" << input.GetByIndices("indices") << ");\n"
<< " sum += value;\n"
Expand All @@ -46,8 +46,8 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader)
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (local_idx == 0) {\n"
<< " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n"
<< " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n"
<< " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / f32(height * " << components_ << ");\n"
<< " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / f32(height * " << components_ << ");\n"
<< " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(" << std::to_string(epsilon_) << "));\n"
<< " let channel_scale = inv_std_dev * f32(" << scale.GetByOffset("channel") << ");\n"
<< " let channel_shift = f32(" << bias.GetByOffset("channel") << ") - sum_final * channel_scale;\n"
Expand Down Expand Up @@ -194,17 +194,19 @@ Status InstanceNorm<false>::ComputeInternal(ComputeContext& context) const {
const auto spatial_size = input->Shape().SizeFromDimension(2);
Tensor channel_scale_shift;
ORT_RETURN_IF_ERROR(ComputeChannelScaleAndShift(context, input, scale, bias, epsilon_, &channel_scale_shift));
const auto output_shape(input_shape_vector);
TensorShape output_shape(input_shape_vector);
Tensor* output = context.Output(0, output_shape);
const auto components = GetMaxComponents(spatial_size);
TensorShapeVector modified_input_shape_vector = {batch_size, channels, spatial_size / components};
TensorShape modified_input_shape(modified_input_shape_vector);
TensorShape modified_output_shape(modified_input_shape_vector);
auto output_size = (modified_output_shape.Size() + components - 1) / components;
auto output_size = modified_output_shape.Size();
TensorShapeVector channel_scale_shift_shape_vector = {batch_size, channels, 1};
TensorShape reduced_channel_scale_shift_shape(channel_scale_shift_shape_vector);
InstanceNormProgram program;
program
.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank, modified_input_shape, components},
{&channel_scale_shift, ProgramTensorMetadataDependency::TypeAndRank, channel_scale_shift.Shape(), 2}})
{&channel_scale_shift, ProgramTensorMetadataDependency::TypeAndRank, reduced_channel_scale_shift_shape, 2}})
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, modified_output_shape, components})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({static_cast<uint32_t>(output_size)});
Expand Down
47 changes: 47 additions & 0 deletions onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/util/include/default_providers.h"

using namespace std;
namespace onnxruntime {
Expand Down Expand Up @@ -290,5 +291,51 @@ TEST(InstanceNormalizationOpTest, InstanceNormNCHW) {
});
}

#ifdef USE_WEBGPU
TEST(InstanceNormalizationOpTest, InstanceNormNCHW_webgpu) {
  // WebGPU EP regression test: InstanceNormalization in NCHW layout with a
  // 1x3 spatial extent per channel (shape {1, 2, 1, 3}).
  OpTester test("InstanceNormalization");
  test.AddAttribute("epsilon", 0.009999999776482582f);

  const vector<int64_t> x_dims = {1, 2, 1, 3};
  const vector<float> x = {1.0f, 2.0f, 3.0f, 2.0f, 2.0f, 2.0f};
  test.AddInput<float>("input", x_dims, x);

  const vector<int64_t> per_channel_dims = {2};
  const vector<float> gamma = {1.0f, 1.0f};
  test.AddInput<float>("scale", per_channel_dims, gamma);

  const vector<float> beta = {0.0f, 2.0f};
  test.AddInput<float>("B", per_channel_dims, beta);

  // Channel 0 normalizes {1, 2, 3} (mean 2, var 2/3, eps 0.01 -> +/-1.21566);
  // channel 1 is constant, so it collapses to its bias of 2.
  const vector<float> expected = {-1.21566f, 0.0f, 1.21566f, 2.0f, 2.0f, 2.0f};
  test.AddOutput<float>("Y", x_dims, expected);

  // false => run the WebGPU EP with the NCHW preferred layout.
  test.ConfigEp(DefaultWebGpuExecutionProvider(false)).RunWithConfig();
}

TEST(InstanceNormalizationOpTest, InstanceNormNCHW_webgpu_2) {
  // WebGPU EP regression test: same operator as above, but with a 2x2 spatial
  // extent per channel (shape {1, 2, 2, 2}) to exercise a multi-row reduction.
  OpTester test("InstanceNormalization");
  test.AddAttribute("epsilon", 0.009999999776482582f);

  const vector<int64_t> x_dims = {1, 2, 2, 2};
  const vector<float> x = {1.0f, 2.0f, 3.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f};
  test.AddInput<float>("input", x_dims, x);

  const vector<int64_t> per_channel_dims = {2};
  const vector<float> gamma = {1.0f, 1.0f};
  test.AddInput<float>("scale", per_channel_dims, gamma);

  const vector<float> beta = {0.0f, 2.0f};
  test.AddInput<float>("B", per_channel_dims, beta);

  // Channel 0 normalizes {1, 2, 3, 2} (mean 2, var 0.5, eps 0.01 -> +/-1.40028);
  // channel 1 is constant, so every element maps to its bias of 2.
  const vector<float> expected = {-1.40028f, 0.0f, 1.40028f, 0.0f, 2.0f, 2.0f, 2.0f, 2.0f};
  test.AddOutput<float>("Y", x_dims, expected);

  // false => run the WebGPU EP with the NCHW preferred layout.
  test.ConfigEp(DefaultWebGpuExecutionProvider(false)).RunWithConfig();
}
#endif

} // namespace test
} // namespace onnxruntime
9 changes: 8 additions & 1 deletion onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,22 @@ std::unique_ptr<IExecutionProvider> DefaultXnnpackExecutionProvider() {
#endif
}

std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider() {
std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider(bool is_nhwc) {
#ifdef USE_WEBGPU
ConfigOptions config_options{};
// Disable storage buffer cache
ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode,
webgpu::options::kBufferCacheMode_Disabled)
.IsOK());
if (!is_nhwc) {
// Enable NCHW support
ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kPreferredLayout,
webgpu::options::kPreferredLayout_NCHW)
.IsOK());
}
return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider();
#else
ORT_UNUSED_PARAMETER(is_nhwc);
return nullptr;
#endif
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/util/include/default_providers.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ std::unique_ptr<IExecutionProvider> DefaultQnnExecutionProvider();
std::unique_ptr<IExecutionProvider> QnnExecutionProviderWithOptions(const ProviderOptions& options,
const SessionOptions* session_options = nullptr);
std::unique_ptr<IExecutionProvider> DefaultXnnpackExecutionProvider();
std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider();
std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider(bool is_nhwc = true);
std::unique_ptr<IExecutionProvider> DefaultCannExecutionProvider();
std::unique_ptr<IExecutionProvider> DefaultDmlExecutionProvider();

Expand Down
Loading