Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix perfornance issue of GroupNorm on CUDA when feature map is small. #46170

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions aten/src/ATen/cuda/DeviceUtils.cuh
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

#include <cuda.h>
#include <c10/util/complex.h>
#include <c10/util/Half.h>
Expand Down
36 changes: 18 additions & 18 deletions aten/src/ATen/native/cpu/group_norm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ void GroupNormKernelImplInternal(
int64_t HxW,
int64_t group,
T eps,
Tensor* Y,
Tensor* mean,
Tensor* rstd) {
Tensor& Y,
Tensor& mean,
Tensor& rstd) {
TORCH_CHECK(X.numel() == N * C * HxW);
TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
TORCH_CHECK(!beta.defined() || beta.numel() == C);
Expand All @@ -35,9 +35,9 @@ void GroupNormKernelImplInternal(
const T* X_data = X.data_ptr<T>();
const T* gamma_data = gamma.defined() ? gamma.data_ptr<T>() : nullptr;
const T* beta_data = beta.defined() ? beta.data_ptr<T>() : nullptr;
T* Y_data = Y->data_ptr<T>();
T* mean_data = mean->data_ptr<T>();
T* rstd_data = rstd->data_ptr<T>();
T* Y_data = Y.data_ptr<T>();
T* mean_data = mean.data_ptr<T>();
T* rstd_data = rstd.data_ptr<T>();
const T s = T(1) / static_cast<T>(D * HxW);
const bool gamma_null = (gamma_data == nullptr);
const bool beta_null = beta_data == nullptr;
Expand Down Expand Up @@ -94,9 +94,9 @@ void GroupNormKernelImpl(
int64_t HxW,
int64_t group,
double eps,
Tensor* Y,
Tensor* mean,
Tensor* rstd) {
Tensor& Y,
Tensor& mean,
Tensor& rstd) {
AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() {
GroupNormKernelImplInternal<scalar_t>(
X,
Expand Down Expand Up @@ -268,9 +268,9 @@ void GroupNormBackwardKernelImplInternal(
int64_t C,
int64_t HxW,
int64_t group,
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
Tensor& dX,
Tensor& dgamma,
Tensor& dbeta) {
TORCH_CHECK(dY.numel() == N * C * HxW);
TORCH_CHECK(X.numel() == N * C * HxW);
TORCH_CHECK(mean.numel() == N * group);
Expand All @@ -282,9 +282,9 @@ void GroupNormBackwardKernelImplInternal(
const T* mean_data = mean.data_ptr<T>();
const T* rstd_data = rstd.data_ptr<T>();
const T* gamma_data = gamma.defined() ? gamma.data_ptr<T>() : nullptr;
T* dX_data = dX->defined() ? dX->data_ptr<T>() : nullptr;
T* dgamma_data = dgamma->defined() ? dgamma->data_ptr<T>() : nullptr;
T* dbeta_data = dbeta->defined() ? dbeta->data_ptr<T>() : nullptr;
T* dX_data = dX.defined() ? dX.data_ptr<T>() : nullptr;
T* dgamma_data = dgamma.defined() ? dgamma.data_ptr<T>() : nullptr;
T* dbeta_data = dbeta.defined() ? dbeta.data_ptr<T>() : nullptr;
Tensor ds = at::empty({N, C}, X.options());
Tensor db = at::empty({N, C}, X.options());
T* ds_data = ds.data_ptr<T>();
Expand Down Expand Up @@ -326,9 +326,9 @@ void GroupNormBackwardKernelImpl(
int64_t C,
int64_t HxW,
int64_t group,
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
Tensor& dX,
Tensor& dgamma,
Tensor& dbeta) {
AT_DISPATCH_FLOATING_TYPES(
X.scalar_type(), "GroupNormBackwardKernelImpl", [&]() {
GroupNormBackwardKernelImplInternal<scalar_t>(
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/block_reduce.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <THC/THCDeviceUtils.cuh>
#include <ATen/cuda/DeviceUtils.cuh>

namespace at {
namespace native {
Expand Down