
Fix performance issue of GroupNorm on CUDA when feature map is small. #46170

Closed
wants to merge 1 commit

Conversation

xiaomengy
Contributor

@xiaomengy xiaomengy commented Oct 11, 2020

Summary: Fix performance issue of GroupNorm on CUDA when feature map is small.

Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GroupNorm"

Differential Revision: D24242738

As mentioned in #46086, the current GroupNorm implementation performs poorly on CUDA when the feature map is small, even compared to the BatchNorm-based implementation used before PyTorch 1.5.1. This PR fixes that performance issue for small feature maps.
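
For context, here is a minimal, hedged sketch (not code from this PR, and not necessarily the exact pre-1.5.1 implementation) of the BatchNorm-based emulation referenced above: each (sample, group) pair is folded into one "channel" of a single-sample batch so that F.batch_norm computes the per-group statistics. The helper name group_norm_via_batch_norm is hypothetical.

```
import torch
import torch.nn.functional as F


def group_norm_via_batch_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    # Fold each (sample, group) pair into a "channel" of a single-sample batch;
    # batch_norm with training=True then computes per-group mean/variance.
    n, c = x.shape[0], x.shape[1]
    y = F.batch_norm(
        x.reshape(1, n * num_groups, -1),
        running_mean=None, running_var=None,
        training=True, eps=eps,
    )
    y = y.reshape(x.shape)
    # Per-channel affine transform, matching nn.GroupNorm's weight/bias.
    if weight is not None:
        y = y * weight.view(1, c, *([1] * (x.dim() - 2)))
    if bias is not None:
        y = y + bias.view(1, c, *([1] * (x.dim() - 2)))
    return y
```

For an input of shape (1024, 512, 7, 7) with 8 groups this should agree with F.group_norm up to numerical tolerance; it is shown only to illustrate the baseline being compared against.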

Benchmark script:

```
import torch
import torch.nn.functional as F

from timeit import Timer

norm = torch.nn.GroupNorm(8, 512).cuda()

num = 5000

sizes = [(1024, 512, 14, 14), (1024, 512, 7, 7), (1024, 512)]


def forward(x):
    _ = norm(x)
    torch.cuda.synchronize()


def backward(y, grad):
    y.backward(grad, retain_graph=True)
    torch.cuda.synchronize()


if __name__ == "__main__":
    # warm up
    x = torch.rand(*(sizes[0]), dtype=torch.float,
                   device="cuda", requires_grad=True)
    for _ in range(100):
        forward(x)

    for size in sizes:
        x = torch.rand(*size, dtype=torch.float,
                       device="cuda", requires_grad=True)
        t = Timer("forward(x)", "from __main__ import forward, x")
        print(f"size = {size}:")
        t1 = t.timeit(num) / num * 1e6
        print(f"avg_forward_time =  {t1}us")

        y = norm(x)
        grad = torch.randn_like(y)
        t = Timer("backward(y, grad)", "from __main__ import backward, y, grad")
        t2 = t.timeit(num) / num * 1e6
        print(f"avg_backward_time = {t2}us")
```

Benchmark result after this PR on a V100 devgpu:

```
size = (1024, 512, 14, 14):
avg_forward_time =  1635.6191572034732us
avg_backward_time = 4140.7730475999415us
size = (1024, 512, 7, 7):
avg_forward_time =  463.6513736099005us
avg_backward_time = 1641.7451039887965us
size = (1024, 512):
avg_forward_time =  66.59087920561433us
avg_backward_time = 128.6882139975205us
```

Benchmark result before this PR on a V100 devgpu:

```
size = (1024, 512, 14, 14):
avg_forward_time =  1636.729855206795us
avg_backward_time = 5488.682465581223us
size = (1024, 512, 7, 7):
avg_forward_time =  465.88476160541177us
avg_backward_time = 3129.9425506033003us
size = (1024, 512):
avg_forward_time =  96.90486900508404us
avg_backward_time = 2319.4099438143894us
```

Results of running the same benchmark script on PyTorch 1.5.1 (build options may differ, so these numbers are for reference only):

```
size = (1024, 512, 14, 14):
avg_forward_time =  2728.9786524139345us
avg_backward_time = 9711.360842408612us
size = (1024, 512, 7, 7):
avg_forward_time =  773.7861637957394us
avg_backward_time = 2496.8661199789494us
size = (1024, 512):
avg_forward_time =  173.00677900202572us
avg_backward_time = 188.9778567943722us
```

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D24242738

Collaborator

@ngimel ngimel left a comment

That's a very nice performance improvement. It would be good to have the comments in the kernels and describe how you choose which path to use. Please make sure that you are testing all the kernel variants that you add (I'm not sure added tests cover everything). Also, can you enable and test bfloat16 dispatch on cuda? Hopefully it should just work, and we are enabling most bfloat16 operations now.
Thanks again for the fix!

aten/src/ATen/native/cuda/group_norm_kernel.cu (review thread, outdated, resolved)
torch/testing/_internal/common_nn.py (review thread, resolved)
@codecov

codecov bot commented Oct 12, 2020

Codecov Report

Merging #46170 into master will increase coverage by 0.00%.
The diff coverage is n/a.


@@           Coverage Diff           @@
##           master   #46170   +/-   ##
=======================================
  Coverage   68.32%   68.33%           
=======================================
  Files         410      410           
  Lines       53793    53793           
=======================================
+ Hits        36756    36757    +1     
+ Misses      17037    17036    -1     
Impacted Files                            Coverage   Δ
torch/testing/_internal/common_nn.py      85.53%     <ø> (ø)
torch/testing/_internal/expecttest.py     78.57%     <0.00%> (+1.02%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 9d389b1...1ccf848.


@xiaomengy
Contributor Author

> That's a very nice performance improvement. It would be good to have the comments in the kernels and describe how you choose which path to use. Please make sure that you are testing all the kernel variants that you add (I'm not sure added tests cover everything). Also, can you enable and test bfloat16 dispatch on cuda? Hopefully it should just work, and we are enabling most bfloat16 operations now.
> Thanks again for the fix!

BFloat16 has been enabled.
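
For illustration, a hedged sanity check (separate from the tests added in this PR; the tolerances are assumptions appropriate for bfloat16 precision) that GroupNorm now runs in bfloat16 on CUDA and stays close to a float32 reference:

```
import torch

x = torch.rand(16, 512, 7, 7, device="cuda")
norm = torch.nn.GroupNorm(8, 512).cuda()

ref = norm(x)                                         # float32 reference
out = norm.to(torch.bfloat16)(x.to(torch.bfloat16))   # bfloat16 path
# bfloat16 has roughly 8 bits of mantissa, so compare with loose tolerances.
print(torch.allclose(ref, out.float(), atol=1e-2, rtol=1e-2))
```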

@dr-ci

dr-ci bot commented Oct 13, 2020

💊 CI failures summary and remediations

As of commit 1ccf848 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚




Fix performance issue of GroupNorm on CUDA when feature map is small. (pytorch#46170)

Summary:
Pull Request resolved: pytorch#46170

Fix performance issue of GroupNorm on CUDA when feature map is small. The benchmark script and the before/after results are the same as in the PR description above.

Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GroupNorm"

Differential Revision: D24242738

fbshipit-source-id: 56c0b5f381ac96cb539e9f01b8c504337a57cd9c

Collaborator

@ngimel ngimel left a comment

Looks good, thank you!

@facebook-github-bot
Contributor

This pull request has been merged in a87a1c1.
