From ae7d22862be83c5ca5ed2d820a11fd8ab523766d Mon Sep 17 00:00:00 2001 From: Dun Date: Thu, 22 Nov 2018 15:42:28 +0800 Subject: [PATCH] Group Norm (#13843) Add group normalization operator. --- AUTHORS.md | 1 + paddle/fluid/API.spec | 1 + paddle/fluid/operators/group_norm_op.cc | 162 ++++++++++ paddle/fluid/operators/group_norm_op.cu | 292 ++++++++++++++++++ paddle/fluid/operators/group_norm_op.h | 197 ++++++++++++ python/paddle/fluid/layers/nn.py | 79 +++++ .../paddle/fluid/tests/unittests/op_test.py | 10 +- .../tests/unittests/test_group_norm_op.py | 143 +++++++++ 8 files changed, 880 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/group_norm_op.cc create mode 100644 paddle/fluid/operators/group_norm_op.cu create mode 100644 paddle/fluid/operators/group_norm_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_group_norm_op.py diff --git a/AUTHORS.md b/AUTHORS.md index 54a1097b50f7a..deafa641203ed 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -25,6 +25,7 @@ | kexinzhao | Ke-Xin Zhao | | kuke | Yi-Bing Liu | | lcy-seso | Ying Cao | +| cjld | Dun Liang | | lipeng-unisound | Peng Li | | liuyuan | Yuan Liu | | livc | Zhao Li | diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index da8941c351571..541c4db1fa091 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -103,6 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)) +paddle.fluid.layers.group_norm ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)) paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, False, False)) paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc new file mode 100644 index 0000000000000..6322659b67f6a --- /dev/null +++ b/paddle/fluid/operators/group_norm_op.cc @@ -0,0 +1,162 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/group_norm_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DataLayout = framework::DataLayout; + +class GroupNormOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of GroupNormOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), + "Output(Y) of GroupNormOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Mean"), + "Output(Mean) of GroupNormOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Variance"), + "Output(Variance) of GroupNormOp should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto channel_num = x_dim[1]; + auto batch_size = x_dim[0]; + auto groups = ctx->Attrs().Get("groups"); + PADDLE_ENFORCE_LE( + groups, channel_num, + "'groups' must be less equal than the number of channels."); + PADDLE_ENFORCE_GE(groups, 1, "'groups' must be greater equal than 1."); + + if (ctx->HasInput("Scale")) { + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], channel_num); + } + if (ctx->HasInput("Bias")) { + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], channel_num); + } + + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); + ctx->SetOutputDim("Mean", {batch_size, groups}); + ctx->SetOutputDim("Variance", {batch_size, groups}); + ctx->ShareLoD("X", "Y"); + } +}; + +class GroupNormOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("Scale", + "Scale is a 1-dimensional tensor of size C" + "that is applied to the output.") + .AsDispensable(); + AddInput("Bias", + "Bias is a 1-dimensional tensor of size C " + "that is applied to the output") + .AsDispensable(); + AddOutput("Y", "Result after normalization."); + AddOutput("Mean", "Mean of each group.").AsIntermediate(); + AddOutput("Variance", "Variance of each group.").AsIntermediate(); + + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 1.0f, + "'epsilon' should be between 0.0 and 1.0."); + }); + AddAttr("groups", "The number of groups that divided from channels.") + .AddCustomChecker([](const int &groups) { + PADDLE_ENFORCE_GT(groups, 0, "'groups' should be greater than zero."); + }); + + AddComment(R"DOC( +Group Normalization + +Refer to `Group Normalization `_ +)DOC"); + } +}; + +class GroupNormGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + // check input + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of GroupNormOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Mean"), + "Input(Mean) of GroupNormOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Variance"), + "Input(Variance) of GroupNormOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) of GroupNormOp should not be null."); + + // check output + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + if (ctx->HasOutput(framework::GradVarName("Scale"))) { + ctx->SetOutputDim(framework::GradVarName("Scale"), + ctx->GetInputDim("Scale")); + } + if (ctx->HasOutput(framework::GradVarName("Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Bias"), + ctx->GetInputDim("Bias")); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + const auto *var = ctx.InputVar(framework::GradVarName("Y")); + if (var == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + const Tensor *t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } + if (t == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + return framework::OpKernelType(framework::ToDataType(t->type()), + ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp); +REGISTER_OP_CPU_KERNEL( + group_norm, ops::GroupNormKernel, + ops::GroupNormKernel); +REGISTER_OP_CPU_KERNEL( + group_norm_grad, + ops::GroupNormGradKernel, + ops::GroupNormGradKernel); diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu new file mode 100644 index 0000000000000..27174630227c8 --- /dev/null +++ b/paddle/fluid/operators/group_norm_op.cu @@ -0,0 +1,292 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/operators/group_norm_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, + int imsize, int groups, + int group_size, T* mean, T* var) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int number = min(group_size, static_cast(C - gid * group_size)); + int ccid = gid * group_size + cid; + if (ccid >= C) return; + T x_mean = 0, x_var = 0; + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + T val = x[(bid * C + ccid) * imsize + imid]; + x_mean += val; + x_var += val * val; + } + x_mean /= number * imsize; + x_var /= number * imsize; + __shared__ T s_mem[2]; + if (threadIdx.x == 0) { + s_mem[0] = s_mem[1] = 0; + } + __syncthreads(); + paddle::platform::CudaAtomicAdd(&s_mem[0], x_mean); + paddle::platform::CudaAtomicAdd(&s_mem[1], x_var); + __syncthreads(); + if (threadIdx.x == 0) { + paddle::platform::CudaAtomicAdd(&mean[bid * groups + gid], s_mem[0]); + paddle::platform::CudaAtomicAdd(&var[bid * groups + gid], s_mem[1]); + } +} + +template +__global__ void GroupNormForward(const T* x, const T* mean, const T* var, + const T* scale, const T* bias, int N, int C, + int imsize, int groups, int group_size, + T epsilon, T* y, T* real_var) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int ccid = gid * group_size + cid; + if (ccid >= C) return; + T x_mean = mean[bid * groups + gid]; + T x_var = var[bid * groups + gid]; + x_var = x_var - x_mean * x_mean; + T var_inv = 1.0 / sqrt(x_var + epsilon); + if (cid == 0 && threadIdx.x == 0) real_var[bid * groups + gid] = x_var; + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + T val = x[(bid * C + ccid) * imsize + imid]; + val = (val - x_mean) * var_inv; + if (scale) val *= scale[gid * group_size + cid]; + if (bias) val += bias[gid * group_size + cid]; + y[(bid * C + ccid) * imsize + imid] = val; + } +} + +template +class GroupNormKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const float epsilon = ctx.Attr("epsilon"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); + auto* x = ctx.Input("X"); + + auto* y = ctx.Output("Y"); + auto* mean = ctx.Output("Mean"); + auto* var = ctx.Output("Variance"); + const auto groups = ctx.Attr("groups"); + + const auto x_dims = x->dims(); + const int group_size = (x_dims[1] - 1) / groups + 1; + + y->mutable_data(ctx.GetPlace()); + mean->mutable_data(ctx.GetPlace()); + var->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + Tensor temp_var; + temp_var.mutable_data(var->dims(), ctx.GetPlace()); + + set_zero(dev_ctx, mean, static_cast(0)); + set_zero(dev_ctx, &temp_var, static_cast(0)); + + auto* x_data = x->data(); + auto* y_data = y->data(); + auto* mean_data = mean->data(); + auto* var_data = var->data(); + auto* temp_var_data = temp_var.data(); + + const T* scale_data = nullptr; + if (scale) scale_data = scale->data(); + const T* bias_data = nullptr; + if (bias) bias_data = bias->data(); + + int imsize = x_dims[2] * x_dims[3]; + int block_size = std::min(512, imsize); + dim3 grid(group_size, groups, x_dims[0]); + dim3 threads(block_size, 1, 1); + GroupNormForwardGetMeanAndVar<<>>( + x_data, x_dims[0], x_dims[1], imsize, groups, group_size, mean_data, + temp_var_data); + GroupNormForward<<>>( + x_data, mean_data, temp_var_data, scale_data, bias_data, x_dims[0], + x_dims[1], imsize, groups, group_size, epsilon, y_data, var_data); + } +}; + +template +__global__ void GroupNormBackwardGetMeanAndVar( + const T* x, const T* mean, const T* var, const T* scale, const T* d_y, + int N, int C, int imsize, int groups, int group_size, T epsilon, T* d_x, + T* d_mean, T* d_var, T* d_scale, T* d_bias) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int number = min(group_size, static_cast(C - gid * group_size)); + int ccid = gid * group_size + cid; + if (ccid >= C) return; + T x_mean = mean[bid * groups + gid]; + T x_var = var[bid * groups + gid]; + T var_inv = 1.0 / sqrt(x_var + epsilon); + T d_var_inv = 0, d_x_mean = 0; + T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0; + + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + T tmp = x[(bid * C + ccid) * imsize + imid]; + T val = (tmp - x_mean) * var_inv; + T dval = d_y[(bid * C + ccid) * imsize + imid]; + if (d_bias) d_bias_data += dval; + if (d_scale) d_scale_data += val * dval; + if (scale) dval = dval * scale[ccid]; + d_var_data += (tmp - x_mean) * dval; + T d_tmp = dval * var_inv; + if (d_x) d_x[(bid * C + ccid) * imsize + imid] = d_tmp; + d_mean_data -= d_tmp; + } + + __shared__ T s_mem[4]; + if (threadIdx.x == 0) { + s_mem[0] = s_mem[1] = 0; + if (d_scale) s_mem[2] = 0; + if (d_bias) s_mem[3] = 0; + } + __syncthreads(); + paddle::platform::CudaAtomicAdd(&s_mem[0], d_mean_data); + paddle::platform::CudaAtomicAdd(&s_mem[1], d_var_data); + if (d_scale) paddle::platform::CudaAtomicAdd(&s_mem[2], d_scale_data); + if (d_bias) paddle::platform::CudaAtomicAdd(&s_mem[3], d_bias_data); + __syncthreads(); + if (threadIdx.x == 0) { + paddle::platform::CudaAtomicAdd(&d_mean[bid * groups + gid], s_mem[0]); + paddle::platform::CudaAtomicAdd(&d_var[bid * groups + gid], s_mem[1]); + if (d_scale) paddle::platform::CudaAtomicAdd(&d_scale[ccid], s_mem[2]); + if (d_bias) paddle::platform::CudaAtomicAdd(&d_bias[ccid], s_mem[3]); + } +} + +template +__global__ void GroupNormBackward(const T* x, const T* mean, const T* var, + const T* d_mean, const T* d_var, int N, int C, + int imsize, int groups, int group_size, + T epsilon, T* d_x) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int number = min(group_size, static_cast(C - gid * group_size)); + int ccid = gid * group_size + cid; + if (ccid >= C) return; + T x_mean = mean[bid * groups + gid]; + T x_var = var[bid * groups + gid]; + T d_x_mean = d_mean[bid * groups + gid]; + T d_var_inv = d_var[bid * groups + gid]; + + T d_x_var = + -1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv; + d_x_mean -= 2 * d_x_var * x_mean; + d_x_var /= number * imsize; + d_x_mean /= number * imsize; + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + T tmp = x[(bid * C + ccid) * imsize + imid]; + if (d_x) + d_x[(bid * C + ccid) * imsize + imid] += d_x_mean + tmp * 2 * d_x_var; + } +} + +template +class GroupNormGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const float epsilon = ctx.Attr("epsilon"); + auto* x = ctx.Input("X"); + auto* mean = ctx.Input("Mean"); + auto* var = ctx.Input("Variance"); + auto* scale = ctx.Input("Scale"); + auto* d_y = ctx.Input(framework::GradVarName("Y")); + const auto groups = ctx.Attr("groups"); + + // init output + auto* d_x = ctx.Output(framework::GradVarName("X")); + auto* d_scale = ctx.Output(framework::GradVarName("Scale")); + auto* d_bias = ctx.Output(framework::GradVarName("Bias")); + + const auto& x_dims = x->dims(); + const int group_size = (x_dims[1] - 1) / groups + 1; + + T* d_x_data = nullptr; + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + d_x_data = d_x->data(); + } + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + + Tensor temp_var; + temp_var.mutable_data(var->dims(), ctx.GetPlace()); + set_zero(dev_ctx, &temp_var, static_cast(0)); + T* temp_var_data = temp_var.data(); + + Tensor temp_mean; + temp_mean.mutable_data(var->dims(), ctx.GetPlace()); + set_zero(dev_ctx, &temp_mean, static_cast(0)); + T* temp_mean_data = temp_mean.data(); + + auto* x_data = x->data(); + auto* y_data = d_y->data(); + auto* mean_data = mean->data(); + auto* var_data = var->data(); + T* d_scale_data = nullptr; + if (d_scale) { + d_scale->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, d_scale, static_cast(0)); + d_scale_data = d_scale->data(); + } + T* d_bias_data = nullptr; + if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, d_bias, static_cast(0)); + d_bias_data = d_bias->data(); + } + + const T* scale_data = nullptr; + if (scale) scale_data = scale->data(); + + int imsize = x_dims[2] * x_dims[3]; + int block_size = std::min(512, imsize); + dim3 grid(group_size, groups, x_dims[0]); + dim3 threads(block_size, 1, 1); + GroupNormBackwardGetMeanAndVar<<>>( + x_data, mean_data, var_data, scale_data, y_data, x_dims[0], x_dims[1], + imsize, groups, group_size, epsilon, d_x_data, temp_mean_data, + temp_var_data, d_scale_data, d_bias_data); + GroupNormBackward<<>>( + x_data, mean_data, var_data, temp_mean_data, temp_var_data, x_dims[0], + x_dims[1], imsize, groups, group_size, epsilon, d_x_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + group_norm, + ops::GroupNormKernel, + ops::GroupNormKernel); +REGISTER_OP_CUDA_KERNEL( + group_norm_grad, + ops::GroupNormGradKernel, + ops::GroupNormGradKernel); diff --git a/paddle/fluid/operators/group_norm_op.h b/paddle/fluid/operators/group_norm_op.h new file mode 100644 index 0000000000000..3d6c6a46a9662 --- /dev/null +++ b/paddle/fluid/operators/group_norm_op.h @@ -0,0 +1,197 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DataLayout = framework::DataLayout; + +template +class GroupNormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const float epsilon = ctx.Attr("epsilon"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); + auto* x = ctx.Input("X"); + + auto* y = ctx.Output("Y"); + auto* mean = ctx.Output("Mean"); + auto* var = ctx.Output("Variance"); + const auto groups = ctx.Attr("groups"); + + const auto x_dims = x->dims(); + const int group_size = (x_dims[1] - 1) / groups + 1; + + y->mutable_data(ctx.GetPlace()); + mean->mutable_data(ctx.GetPlace()); + var->mutable_data(ctx.GetPlace()); + + auto* x_data = x->data(); + auto* y_data = y->data(); + auto* mean_data = mean->data(); + auto* var_data = var->data(); + + const T* scale_data = nullptr; + if (scale) scale_data = scale->data(); + const T* bias_data = nullptr; + if (bias) bias_data = bias->data(); + + int imsize = x_dims[2] * x_dims[3]; + auto* iter_x_data = x_data; + auto* iter_y_data = y_data; + for (int bid = 0; bid < x_dims[0]; bid++) + for (int gid = 0; gid < groups; gid++) { + T x_mean = 0, x_var = 0; + int number = std::min(group_size, + static_cast(x_dims[1] - gid * group_size)); + auto* tmp = iter_x_data; + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; imid++, iter_x_data++) { + x_mean += iter_x_data[0]; + x_var += iter_x_data[0] * iter_x_data[0]; + } + } + x_mean /= number * imsize; + x_var /= number * imsize; + x_var = x_var - x_mean * x_mean; + T var_inv = 1.0 / sqrt(x_var + epsilon); + mean_data[bid * groups + gid] = x_mean; + var_data[bid * groups + gid] = x_var; + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; imid++, tmp++, iter_y_data++) { + T val = (tmp[0] - x_mean) * var_inv; + if (scale_data) val *= scale_data[gid * group_size + cid]; + if (bias_data) val += bias_data[gid * group_size + cid]; + iter_y_data[0] = val; + } + } + } + } +}; + +template +class GroupNormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const float epsilon = ctx.Attr("epsilon"); + auto* x = ctx.Input("X"); + auto* mean = ctx.Input("Mean"); + auto* var = ctx.Input("Variance"); + auto* scale = ctx.Input("Scale"); + auto* d_y = ctx.Input(framework::GradVarName("Y")); + const auto groups = ctx.Attr("groups"); + + // init output + auto* d_x = ctx.Output(framework::GradVarName("X")); + auto* d_scale = ctx.Output(framework::GradVarName("Scale")); + auto* d_bias = ctx.Output(framework::GradVarName("Bias")); + + const auto& x_dims = x->dims(); + const int group_size = (x_dims[1] - 1) / groups + 1; + + // TODO(liangdun): need to check d_x is null + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + T* d_x_data = nullptr; + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, d_x, static_cast(0)); + d_x_data = d_x->data(); + } + + auto* x_data = x->data(); + auto* y_data = d_y->data(); + auto* mean_data = mean->data(); + auto* var_data = var->data(); + T* d_scale_data = nullptr; + if (d_scale) { + d_scale->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, d_scale, static_cast(0)); + d_scale_data = d_scale->data(); + } + T* d_bias_data = nullptr; + if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, d_bias, static_cast(0)); + d_bias_data = d_bias->data(); + } + + const T* scale_data = nullptr; + if (scale) scale_data = scale->data(); + + int imsize = x_dims[2] * x_dims[3]; + auto* iter_x_data = x_data; + auto* iter_d_x_data = d_x_data; + auto* iter_y_data = y_data; + for (int bid = 0; bid < x_dims[0]; bid++) + for (int gid = 0; gid < groups; gid++) { + T x_mean = mean_data[bid * groups + gid]; + T x_var = var_data[bid * groups + gid]; + T var_inv = 1.0 / sqrt(x_var + epsilon); + int number = std::min(group_size, + static_cast(x_dims[1] - gid * group_size)); + auto* tmp = iter_x_data; + auto* tmp2 = iter_d_x_data; + T d_var_inv = 0, d_x_mean = 0; + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; + imid++, tmp++, iter_y_data++, iter_d_x_data++) { + T val = (tmp[0] - x_mean) * var_inv; + T dval = iter_y_data[0]; + if (d_bias_data) d_bias_data[gid * group_size + cid] += dval; + if (d_scale_data) + d_scale_data[gid * group_size + cid] += val * dval; + if (scale_data) dval = scale_data[gid * group_size + cid] * dval; + + d_var_inv += (tmp[0] - x_mean) * dval; + T d_tmp = dval * var_inv; + if (d_x_data) iter_d_x_data[0] += d_tmp; + d_x_mean -= d_tmp; + } + } + + T d_x_var = + -1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv; + d_x_mean -= 2 * d_x_var * x_mean; + d_x_var /= number * imsize; + d_x_mean /= number * imsize; + + iter_d_x_data = tmp2; + + if (d_x_data) { + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; + imid++, iter_x_data++, iter_d_x_data++) { + iter_d_x_data[0] += d_x_mean; + iter_d_x_data[0] += iter_x_data[0] * 2 * d_x_var; + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e0cc09a4c76ca..ccd9175b64d46 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -85,6 +85,7 @@ 'row_conv', 'multiplex', 'layer_norm', + 'group_norm', 'softmax_with_cross_entropy', 'smooth_l1', 'one_hot', @@ -2547,6 +2548,84 @@ def layer_norm(input, return helper.append_activation(layer_norm_out) +@templatedoc() +def group_norm(input, + groups, + epsilon=1e-05, + param_attr=None, + bias_attr=None, + act=None, + data_layout='NCHW', + name=None): + """ + **Group Normalization Layer** + + Refer to `Group Normalization ` + + Args: + input(Variable): The input tensor variable. + groups(int): The number of groups that divided from channels. + epsilon(float): The small value added to the variance to prevent + division by zero. + param_attr(ParamAttr|None): The parameter attribute for the learnable + scale :math:`g`. If it is set to False, no scale will be added to the output units. + If it is set to None, the bias is initialized one. Default: None. + bias_attr(ParamAttr|None): The parameter attribute for the learnable + bias :math:`b`. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. Default: None. + act(str): Activation to be applied to the output of group normalizaiton. + data_layout(string|NCHW): Only NCHW is supported. + name (str): The name of this layer. It is optional. + + Returns: + Variable: A tensor variable which is the result after applying group normalization on the input. + + Examples: + + >>> data = fluid.layers.data(name='data', shape=[8, 32, 32], + >>> dtype='float32') + >>> x = fluid.layers.group_norm(input=data, groups=4) + """ + helper = LayerHelper('group_norm', **locals()) + dtype = helper.input_dtype() + + # create intput and parameters + inputs = {'X': input} + input_shape = input.shape + if data_layout != 'NCHW': + raise ValueError("unsupported data layout:" + data_layout) + param_shape = [input_shape[1]] + if param_attr: + scale = helper.create_parameter( + attr=helper.param_attr, + shape=param_shape, + dtype=dtype, + default_initializer=Constant(1.0)) + inputs['Scale'] = scale + if bias_attr: + bias = helper.create_parameter( + attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) + inputs['Bias'] = bias + + # create output + mean_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) + variance_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) + group_norm_out = helper.create_tmp_variable(dtype) + + helper.append_op( + type="group_norm", + inputs=inputs, + outputs={ + "Y": group_norm_out, + "Mean": mean_out, + "Variance": variance_out, + }, + attrs={"epsilon": epsilon, + "groups": groups}) + + return helper.append_activation(group_norm_out) + + def conv2d_transpose(input, num_filters, output_size=None, diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index c195a28e452fb..271b9c740fd99 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -381,8 +381,8 @@ def check_output_customized(self, checker): outs.sort(key=len) checker(outs) - def __assert_is_close(self, numeric_grads, analytic_grads, names, - max_relative_error, msg_prefix): + def _assert_is_close(self, numeric_grads, analytic_grads, names, + max_relative_error, msg_prefix): for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names): abs_a = np.abs(a) @@ -451,9 +451,9 @@ def check_grad_with_place(self, analytic_grads = self._get_gradient(inputs_to_check, place, output_names, no_grad_set) - self.__assert_is_close(numeric_grads, analytic_grads, inputs_to_check, - max_relative_error, - "Gradient Check On %s" % str(place)) + self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, + max_relative_error, + "Gradient Check On %s" % str(place)) @staticmethod def _numpy_to_lod_tensor(np_value, lod, place): diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op.py b/python/paddle/fluid/tests/unittests/test_group_norm_op.py new file mode 100644 index 0000000000000..0b6d039f05089 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op.py @@ -0,0 +1,143 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np + +from operator import mul +import paddle.fluid.core as core +import paddle.fluid as fluid +from op_test import OpTest + +from testsuite import create_op + + +def group_norm_naive(x, scale, bias, epsilon, groups): + N, C, H, W = x.shape + G = groups + x = x.reshape((N * G, -1)) + mean = np.mean(x, axis=1, keepdims=True) + var = np.var(x, axis=1, keepdims=True) + output = (x - mean) / np.sqrt(var + epsilon) + output = output.reshape((N, C, H, W)) * scale.reshape( + (-1, 1, 1)) + bias.reshape((-1, 1, 1)) + return output, mean.reshape((N, G)), var.reshape((N, G)) + + +class TestGroupNormOp(OpTest): + def setUp(self): + self.op_type = "group_norm" + self.data_format = "NCHW" + self.dtype = np.float32 + self.shape = (2, 4, 3, 3) + self.attrs = {'epsilon': 1e-5, 'groups': 2} + self.compare_between_place = False + self.init_test_case() + + input = np.random.random(self.shape).astype(self.dtype) + scale = np.random.random([self.shape[1]]).astype(self.dtype) + bias = np.random.random([self.shape[1]]).astype(self.dtype) + output, mean, var = group_norm_naive( + input, scale, bias, self.attrs['epsilon'], self.attrs['groups']) + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(input), + 'Scale': OpTest.np_dtype_to_fluid_dtype(scale), + 'Bias': OpTest.np_dtype_to_fluid_dtype(bias) + } + self.outputs = {'Y': output, 'Mean': mean, 'Variance': var} + + def test_check_output(self): + atol = 1e-4 + place = core.CPUPlace() + self.check_output_with_place(place, atol=atol) + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=atol) + + def do_compare_between_place(self): + if not core.is_compiled_with_cuda(): return + place = core.CPUPlace() + place2 = core.CUDAPlace(0) + self.scope = core.Scope() + op_inputs = self.inputs if hasattr(self, "inputs") else dict() + op_outputs = self.outputs if hasattr(self, "outputs") else dict() + op_attrs = self.attrs if hasattr(self, "attrs") else dict() + self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, + op_attrs) + inputs_to_check = set(['X', 'Scale', 'Bias']) + output_names = 'Y' + cpu_grads = self._get_gradient(inputs_to_check, place, output_names, + None) + gpu_grads = self._get_gradient(inputs_to_check, place2, output_names, + None) + self._assert_is_close(cpu_grads, gpu_grads, inputs_to_check, 0.005, + "Gradient Check On %s" % str(place)) + + def test_check_grad(self): + if self.compare_between_place: + self.do_compare_between_place() + return + place = core.CPUPlace() + self.check_grad_with_place( + place, set(['X', 'Scale', 'Bias']), 'Y', max_relative_error=0.01) + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + set(['X', 'Scale', 'Bias']), + 'Y', + max_relative_error=0.01) + + def init_test_case(self): + pass + + +class TestGroupNormOp1(TestGroupNormOp): + def init_test_case(self): + self.attrs['groups'] = 1 + + +class TestGroupNormOp2(TestGroupNormOp): + def init_test_case(self): + self.attrs['groups'] = 4 + + +class TestGroupNormOpBigEps1(TestGroupNormOp): + def init_test_case(self): + self.attrs['groups'] = 1 + self.attrs['epsilon'] = 0.5 + + +class TestGroupNormOpBigEps2(TestGroupNormOp): + def init_test_case(self): + self.attrs['groups'] = 4 + self.attrs['epsilon'] = 0.5 + + +class TestGroupNormOpBigEps3(TestGroupNormOp): + def init_test_case(self): + self.attrs['epsilon'] = 0.5 + + +class TestGroupNormOpLargeData(TestGroupNormOp): + def init_test_case(self): + self.shape = (2, 32, 64, 64) + self.attrs['groups'] = 8 + self.compare_between_place = True + + +if __name__ == '__main__': + unittest.main()