
max_pool2d CPU forward performance is poor #51393

Open
jamesr66a opened this issue Jan 30, 2021 · 2 comments
Labels
module: performance Issues related to performance, either of kernel code or framework glue · module: pooling · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@jamesr66a
Collaborator

jamesr66a commented Jan 30, 2021

```python
import torch
import torch.fx

import torchvision.models as models

# Set up ResNet-18 in inference mode
rn18 = models.resnet18()
rn18.eval()
rn18.requires_grad_(False)

N, C, H, W = 10, 3, 224, 224

def rn18_bench(input, chrome_trace_filename):
    with torch.autograd.profiler.profile(record_shapes=True) as prof:
        rn18(input)
    prof.export_chrome_trace(chrome_trace_filename)

x = torch.randn(N, C, H, W, requires_grad=False)
rn18_bench(x, 'maxpool_nchw.json')

x_nhwc = x.contiguous(memory_format=torch.channels_last)
rn18_bench(x_nhwc, 'maxpool_nhwc.json')
```

```python
# Roofline-analyze the maxpool layer in isolation


def maxpool_bench(input, name):
    import time

    # Warmup iteration. We don't care if things get put in cache, because
    # we're not even close to being bandwidth/cache bound.
    rn18.maxpool(input)

    NITER = 100
    s = time.time()
    for _ in range(NITER):
        out = rn18.maxpool(input)
    e = time.time()

    time_per_iter_sec = (e - s) / NITER
    bytes_in = input.numel() * input.element_size()
    bytes_out = out.numel() * out.element_size()
    gbps = (bytes_in + bytes_out) / time_per_iter_sec / 1e9

    # One comparison per kernel element per output element
    total_kernel_size = rn18.maxpool.kernel_size ** 2
    gflops = out.numel() * total_kernel_size / time_per_iter_sec / 1e9

    print(name, gbps, 'GB/s', gflops, 'GFLOP/s')

max_pool_input = torch.randn(10, 64, 112, 112, requires_grad=False)
maxpool_bench(max_pool_input, 'maxpool NCHW')

max_pool_input_nhwc = max_pool_input.contiguous(memory_format=torch.channels_last)
maxpool_bench(max_pool_input_nhwc, 'maxpool NHWC')
```

Profiler trace screenshots for NCHW and NHWC (images omitted).

Results from running on my machine (lscpu):

```
maxpool NCHW 4.765286484398438 GB/s 2.1443789179792976 GFLOP/s
maxpool NHWC 3.958536155344408 GB/s 1.7813412699049835 GFLOP/s
```

These results are well below both the machine's peak memory/cache bandwidth and its peak GFLOP/s.
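For a sanity check on those numbers, here is a back-of-envelope tally of the traffic and work per call. This is a sketch assuming ResNet-18's standard maxpool configuration (kernel 3, stride 2, padding 1), which maps the 112×112 input to a 56×56 output:

```python
# Back-of-envelope roofline inputs for the 10x64x112x112 maxpool above.
# Assumes ResNet-18's maxpool: kernel 3, stride 2, padding 1 -> 56x56 output.
N, C, H, W = 10, 64, 112, 112
Ho, Wo = 56, 56
elem = 4  # bytes per float32 element

bytes_in = N * C * H * W * elem     # bytes read per call
bytes_out = N * C * Ho * Wo * elem  # bytes written per call
comparisons = N * C * Ho * Wo * 9   # one compare per 3x3 window element

print(f"{(bytes_in + bytes_out) / 1e6:.1f} MB moved per call")
print(f"{comparisons / 1e6:.1f} M comparisons per call")
```

So each call moves roughly 40 MB; at the measured ~4.8 GB/s that works out to on the order of 8 ms per call, which a machine with tens of GB/s of DRAM bandwidth should comfortably beat.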

perf indicates that we're hitting `max_pool2d_with_indices_single_out_frame`.

A few things pop out here:

  • We unconditionally compute indices for the backward pass, even when running the network in inference mode.
  • As written, the scalar loop cannot be vectorized, and disassembly confirms that it isn't (ucomiss on scalar single-precision floats; screenshot omitted).
  • There is no difference for the channels-last memory layout. Making this fast for NCHW is understandably hard, but NHWC should be relatively trivial; the quantized op kernel already does this.
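To illustrate why NHWC is the easy case, here is a minimal NumPy sketch (not PyTorch's actual kernel; pooling parameters default to ResNet-18's maxpool). In channels-last layout the innermost reduction is an elementwise max over contiguous length-C channel vectors, which is unit-stride and therefore SIMD-friendly:

```python
import numpy as np

def maxpool2d_nhwc(x, k=3, stride=2, pad=1):
    """Illustrative NHWC max pool (defaults match ResNet-18's maxpool)."""
    n, h, w, c = x.shape
    # Pad spatially with -inf so padded positions never win the max
    xp = np.full((n, h + 2 * pad, w + 2 * pad, c), -np.inf, dtype=x.dtype)
    xp[:, pad:pad + h, pad:pad + w, :] = x
    ho = (h + 2 * pad - k) // stride + 1
    wo = (w + 2 * pad - k) // stride + 1
    out = np.full((n, ho, wo, c), -np.inf, dtype=x.dtype)
    for i in range(ho):
        for j in range(wo):
            for di in range(k):
                for dj in range(k):
                    # Elementwise max over a contiguous (n, c) slab:
                    # unit-stride along channels, so it vectorizes naturally
                    np.maximum(out[:, i, j, :],
                               xp[:, i * stride + di, j * stride + dj, :],
                               out=out[:, i, j, :])
    return out
```

This mirrors the structure that the quantized op kernel exploits; no per-element index bookkeeping or strided gather is needed in the inner loop.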

cc @VitalyFedyunin @ngimel @heitorschueroff

@ngimel
Collaborator

ngimel commented Jan 30, 2021

cc @heitorschueroff

@ngimel ngimel added the module: performance, module: pooling, and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Jan 30, 2021
@VitalyFedyunin
Contributor

#48917 targets improving this situation.
