Fix performance issue of GroupNorm on CUDA when feature map is small. (#46170)

Summary:
Pull Request resolved: #46170

Fix performance issue of GroupNorm on CUDA when feature map is small.

Benchmark script:

```
import torch
import torch.nn.functional as F

from timeit import Timer

norm = torch.nn.GroupNorm(8, 512).cuda()

num = 5000

sizes = [(1024, 512, 14, 14), (1024, 512, 7, 7), (1024, 512)]

def forward(x):
    _ = norm(x)
    torch.cuda.synchronize()

def backward(y, grad):
    y.backward(grad, retain_graph=True)
    torch.cuda.synchronize()

if __name__ == "__main__":
    # warm up
    x = torch.rand(*(sizes[0]), dtype=torch.float,
                   device="cuda", requires_grad=True)
    for _ in range(100):
        forward(x)

    for size in sizes:
        x = torch.rand(*size, dtype=torch.float,
                       device="cuda", requires_grad=True)
        t = Timer("forward(x)", "from __main__ import forward, x")
        print(f"size = {size}:")
        t1 = t.timeit(num) / num * 1e6
        print(f"avg_forward_time =  {t1}us")

        y = norm(x)
        grad = torch.randn_like(y)
        t = Timer("backward(y, grad)", "from __main__ import backward, y, grad")
        t2 = t.timeit(num) / num * 1e6
        print(f"avg_backward_time = {t2}us")
```
Benchmark result before this Diff:
```
size = (1024, 512, 14, 14):
avg_forward_time =  1636.729855206795us
avg_backward_time = 5488.682465581223us
size = (1024, 512, 7, 7):
avg_forward_time =  465.88476160541177us
avg_backward_time = 3129.9425506033003us
size = (1024, 512):
avg_forward_time =  96.90486900508404us
avg_backward_time = 2319.4099438143894us
```

Benchmark result after this Diff:
```
size = (1024, 512, 14, 14):
avg_forward_time =  1635.6191572034732us
avg_backward_time = 4140.7730475999415us
size = (1024, 512, 7, 7):
avg_forward_time =  463.6513736099005us
avg_backward_time = 1641.7451039887965us
size = (1024, 512):
avg_forward_time =  66.59087920561433us
avg_backward_time = 128.6882139975205us
```
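
For quick reference, here is a small helper (not part of the original PR) that derives the speedup ratios implied by the timings quoted above; the figures are copied from the benchmark output, so the ratios are approximate:

```
# Hypothetical helper, not part of the PR: compute speedups from the
# before/after average times quoted above (in microseconds).
before = {
    (1024, 512, 14, 14): (1636.73, 5488.68),
    (1024, 512, 7, 7): (465.88, 3129.94),
    (1024, 512): (96.90, 2319.41),
}
after = {
    (1024, 512, 14, 14): (1635.62, 4140.77),
    (1024, 512, 7, 7): (463.65, 1641.75),
    (1024, 512): (66.59, 128.69),
}

for size in before:
    fwd_speedup = before[size][0] / after[size][0]
    bwd_speedup = before[size][1] / after[size][1]
    print(f"{size}: forward {fwd_speedup:.2f}x, backward {bwd_speedup:.2f}x")
```

The largest gain is in the backward pass for the (1024, 512) input, which drops from roughly 2319us to 129us (about 18x), while the larger feature maps see more modest backward improvements (about 1.9x at 7x7 and 1.3x at 14x14).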

Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GroupNorm"
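
For readers without access to buck, a minimal standalone sanity check in the same spirit as the test plan is sketched below. It is not the actual test; it only uses the public PyTorch API and assumes a CUDA device is available, comparing the CUDA GroupNorm forward/backward results against the CPU implementation.

```
import torch

# Sketch only: check CUDA GroupNorm against the CPU reference on a couple
# of shapes, including the small-feature-map case this change targets.
torch.manual_seed(0)
norm_cpu = torch.nn.GroupNorm(8, 512)
norm_cuda = torch.nn.GroupNorm(8, 512).cuda()
norm_cuda.load_state_dict(norm_cpu.state_dict())

for size in [(4, 512, 7, 7), (4, 512)]:
    x_cpu = torch.rand(*size, dtype=torch.float, requires_grad=True)
    x_cuda = x_cpu.detach().cuda().requires_grad_(True)

    y_cpu = norm_cpu(x_cpu)
    y_cuda = norm_cuda(x_cuda)
    assert torch.allclose(y_cpu, y_cuda.cpu(), atol=1e-5)

    grad = torch.randn_like(y_cpu)
    y_cpu.backward(grad)
    y_cuda.backward(grad.cuda())
    assert torch.allclose(x_cpu.grad, x_cuda.grad.cpu(), atol=1e-4)

print("CUDA GroupNorm matches CPU reference on the tested shapes")
```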

Differential Revision: D24242738

fbshipit-source-id: 56c0b5f381ac96cb539e9f01b8c504337a57cd9c
xiaomengy authored and facebook-github-bot committed Oct 14, 2020
1 parent 9d389b1 commit 1ccf848
Showing 7 changed files with 674 additions and 262 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/cuda/DeviceUtils.cuh
```
@@ -1,3 +1,5 @@
+#pragma once
+
 #include <cuda.h>
 #include <c10/util/complex.h>
 #include <c10/util/Half.h>
```
36 changes: 18 additions & 18 deletions aten/src/ATen/native/cpu/group_norm_kernel.cpp
```
@@ -24,9 +24,9 @@ void GroupNormKernelImplInternal(
     int64_t HxW,
     int64_t group,
     T eps,
-    Tensor* Y,
-    Tensor* mean,
-    Tensor* rstd) {
+    Tensor& Y,
+    Tensor& mean,
+    Tensor& rstd) {
   TORCH_CHECK(X.numel() == N * C * HxW);
   TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
   TORCH_CHECK(!beta.defined() || beta.numel() == C);
@@ -35,9 +35,9 @@
   const T* X_data = X.data_ptr<T>();
   const T* gamma_data = gamma.defined() ? gamma.data_ptr<T>() : nullptr;
   const T* beta_data = beta.defined() ? beta.data_ptr<T>() : nullptr;
-  T* Y_data = Y->data_ptr<T>();
-  T* mean_data = mean->data_ptr<T>();
-  T* rstd_data = rstd->data_ptr<T>();
+  T* Y_data = Y.data_ptr<T>();
+  T* mean_data = mean.data_ptr<T>();
+  T* rstd_data = rstd.data_ptr<T>();
   const T s = T(1) / static_cast<T>(D * HxW);
   const bool gamma_null = (gamma_data == nullptr);
   const bool beta_null = beta_data == nullptr;
@@ -94,9 +94,9 @@ void GroupNormKernelImpl(
     int64_t HxW,
     int64_t group,
     double eps,
-    Tensor* Y,
-    Tensor* mean,
-    Tensor* rstd) {
+    Tensor& Y,
+    Tensor& mean,
+    Tensor& rstd) {
   AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() {
     GroupNormKernelImplInternal<scalar_t>(
         X,
@@ -268,9 +268,9 @@ void GroupNormBackwardKernelImplInternal(
     int64_t C,
     int64_t HxW,
     int64_t group,
-    Tensor* dX,
-    Tensor* dgamma,
-    Tensor* dbeta) {
+    Tensor& dX,
+    Tensor& dgamma,
+    Tensor& dbeta) {
   TORCH_CHECK(dY.numel() == N * C * HxW);
   TORCH_CHECK(X.numel() == N * C * HxW);
   TORCH_CHECK(mean.numel() == N * group);
@@ -282,9 +282,9 @@
   const T* mean_data = mean.data_ptr<T>();
   const T* rstd_data = rstd.data_ptr<T>();
   const T* gamma_data = gamma.defined() ? gamma.data_ptr<T>() : nullptr;
-  T* dX_data = dX->defined() ? dX->data_ptr<T>() : nullptr;
-  T* dgamma_data = dgamma->defined() ? dgamma->data_ptr<T>() : nullptr;
-  T* dbeta_data = dbeta->defined() ? dbeta->data_ptr<T>() : nullptr;
+  T* dX_data = dX.defined() ? dX.data_ptr<T>() : nullptr;
+  T* dgamma_data = dgamma.defined() ? dgamma.data_ptr<T>() : nullptr;
+  T* dbeta_data = dbeta.defined() ? dbeta.data_ptr<T>() : nullptr;
   Tensor ds = at::empty({N, C}, X.options());
   Tensor db = at::empty({N, C}, X.options());
   T* ds_data = ds.data_ptr<T>();
@@ -326,9 +326,9 @@ void GroupNormBackwardKernelImpl(
     int64_t C,
     int64_t HxW,
     int64_t group,
-    Tensor* dX,
-    Tensor* dgamma,
-    Tensor* dbeta) {
+    Tensor& dX,
+    Tensor& dgamma,
+    Tensor& dbeta) {
   AT_DISPATCH_FLOATING_TYPES(
       X.scalar_type(), "GroupNormBackwardKernelImpl", [&]() {
         GroupNormBackwardKernelImplInternal<scalar_t>(
```
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/block_reduce.cuh
```
@@ -1,6 +1,6 @@
 #pragma once
 
-#include <THC/THCDeviceUtils.cuh>
+#include <ATen/cuda/DeviceUtils.cuh>
 
 namespace at {
 namespace native {
```
