Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix perfornance issue of GroupNorm on CUDA when feature map is small. (…
…#46170) Summary: Pull Request resolved: #46170 Fix perfornance issue of GroupNorm on CUDA when feature map is small. Benchmark script: ``` import torch import torch.nn.functional as F from timeit import Timer norm = torch.nn.GroupNorm(8, 512).cuda() num = 5000 sizes = [(1024, 512, 14, 14), (1024, 512, 7, 7), (1024, 512)] def forward(x): _ = norm(x) torch.cuda.synchronize() def backward(y, grad): y.backward(grad, retain_graph=True) torch.cuda.synchronize() if __name__ == "__main__": # warm up x = torch.rand(*(sizes[0]), dtype=torch.float, device="cuda", requires_grad=True) for _ in range(100): forward(x) for size in sizes: x = torch.rand(*size, dtype=torch.float, device="cuda", requires_grad=True) t = Timer("forward(x)", "from __main__ import forward, x") print(f"size = {size}:") t1 = t.timeit(num) / num * 1e6 print(f"avg_forward_time = {t1}us") y = norm(x) grad = torch.randn_like(y) t = Timer("backward(y, grad)", "from __main__ import backward, y, grad") t2 = t.timeit(num) / num * 1e6 print(f"avg_backward_time = {t2}us") ``` Benchmark result before this Diff: ``` size = (1024, 512, 14, 14): avg_forward_time = 1636.729855206795us avg_backward_time = 5488.682465581223us size = (1024, 512, 7, 7): avg_forward_time = 465.88476160541177us avg_backward_time = 3129.9425506033003us size = (1024, 512): avg_forward_time = 96.90486900508404us avg_backward_time = 2319.4099438143894us ``` Benchmark result after this Diff: ``` size = (1024, 512, 14, 14): avg_forward_time = 1635.6191572034732us avg_backward_time = 4140.7730475999415us size = (1024, 512, 7, 7): avg_forward_time = 463.6513736099005us avg_backward_time = 1641.7451039887965us size = (1024, 512): avg_forward_time = 66.59087920561433us avg_backward_time = 128.6882139975205us ``` Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GroupNorm" Differential Revision: D24242738 fbshipit-source-id: 56c0b5f381ac96cb539e9f01b8c504337a57cd9c
- Loading branch information