grid_sample backward pass performance scales poorly with input size #64977

Closed
@to-mi

Description

🐛 Bug

The backward pass of grid_sample (to get the gradient with respect to grid) depends heavily on the input size (at least with mode="bilinear"). I don't see why this should be the case, since the grid determines which input pixels affect the computation (but perhaps I'm mistaken?). It's also possible to write a grid_sample implementation from basic PyTorch operations whose performance doesn't scale as badly with the input size (though it's, of course, otherwise not as heavily optimized as grid_sample); a minimal sketch of what I mean is included right below.
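
Sketch of a basic-ops implementation (grid_sample_manual is just an illustrative name, not an existing API). It only covers mode="bilinear", padding_mode="border", align_corners=True, and I'm not claiming it matches grid_sample exactly at the borders. The point is that the gradient with respect to grid only flows through the grid-sized interpolation weights, so I'd expect its backward cost to be largely independent of H_in and W_in (apart from the initial indexing into the input).

import torch


def grid_sample_manual(input, grid):
    # Hypothetical helper: emulates mode="bilinear", padding_mode="border",
    # align_corners=True for a (N, C, H_in, W_in) input and a
    # (N, H_out, W_out, 2) grid with values in [-1, 1].
    N, C, H_in, W_in = input.shape
    _, H_out, W_out, _ = grid.shape

    # Normalized -> pixel coordinates (align_corners=True), clamped to the
    # image extent to approximate "border" padding.
    x = ((grid[..., 0] + 1.0) * (W_in - 1) / 2.0).clamp(0, W_in - 1)
    y = ((grid[..., 1] + 1.0) * (H_in - 1) / 2.0).clamp(0, H_in - 1)

    # Integer corner coordinates and fractional interpolation weights; the
    # weights are what carry the gradient back to the grid.
    x0, y0 = x.floor(), y.floor()
    x1 = (x0 + 1).clamp(max=W_in - 1)
    y1 = (y0 + 1).clamp(max=H_in - 1)
    wx = (x - x0).unsqueeze(1)  # (N, 1, H_out, W_out)
    wy = (y - y0).unsqueeze(1)

    flat = input.reshape(N, C, -1)  # flatten spatial dims for gather

    def corner(ix, iy):
        # Gather the pixel at (iy, ix) for every output location and channel.
        idx = (iy * W_in + ix).long().view(N, 1, -1).expand(-1, C, -1)
        return flat.gather(2, idx).view(N, C, H_out, W_out)

    v00, v01 = corner(x0, y0), corner(x1, y0)
    v10, v11 = corner(x0, y1), corner(x1, y1)

    # Standard bilinear interpolation of the four corner values.
    top = v00 + (v01 - v00) * wx
    bottom = v10 + (v11 - v10) * wx
    return top + (bottom - top) * wy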

To Reproduce

Code to time with different input sizes:

import timeit

import torch


def grid_sample_test(input, grid, backward):
    if backward and grid.grad is not None:
        grid.grad.zero_()
    samples = torch.nn.functional.grid_sample(
        input,
        grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
    m = samples.mean()
    if backward:
        m.backward()

    return samples


_input = None
_grid = None
_backward = None

if __name__ == "__main__":
    torch.manual_seed(15)
    torch.set_num_threads(1)

    N = 100
    C = 2
    repeats = 100
    H_out = 13
    W_out = 13
    dtype = torch.double
    devices = ["cpu"]
    backwards = [False, True]

    input_sizes = [(30, 40), (300, 400), (1000, 1200)]

    grid_cpu = 2.0 * torch.rand((N, H_out, W_out, 2), dtype=dtype) - 1.0

    for input_size in input_sizes:
        H_in, W_in = input_size
        input_cpu = torch.rand(
            (1, C, H_in, W_in),
            requires_grad=False,
            dtype=dtype,
        ).expand((N, -1, -1, -1))

        for _backward in backwards:
            for device in devices:
                _grid = grid_cpu.clone().detach().to(device).requires_grad_(True)
                _input = input_cpu.to(device)

                t = timeit.timeit(
                    "grid_sample_test(_input, _grid, _backward)",
                    globals=globals(),
                    number=repeats,
                )
                print(
                    f"device={device:>4} backward={str(_backward):>5} input size={H_in:>4}x{W_in:<4}: {t:5.2f}"
                )

Example output (the last column is the time in seconds):

device= cpu backward=False input size=  30x40  :  0.03
device= cpu backward= True input size=  30x40  :  0.54
device= cpu backward=False input size= 300x400 :  0.04
device= cpu backward= True input size= 300x400 :  5.19
device= cpu backward=False input size=1000x1200:  0.11
device= cpu backward= True input size=1000x1200: 48.56

Expected behavior

I would expect the performance to scale less drastically with input size. In the timings above the forward pass barely changes, while for the larger inputs the backward time grows roughly in proportion to the number of input pixels (5.19 s → 48.56 s for a 10× increase in pixels).

Environment

collect_env.py output:

Collecting environment information...
PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Pop!_OS 20.04 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.11 (default, Aug  3 2021, 15:09:35)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.11.0-7633-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce MX250
Nvidia driver version: 470.57.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.20.3
[pip3] torch==1.9.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              hfd86e86_1  
[conda] mkl                       2021.3.0           h06a4308_520  
[conda] mkl-service               2.4.0            py38h7f8727e_0  
[conda] mkl_fft                   1.3.0            py38h42c9631_2  
[conda] mkl_random                1.2.2            py38h51133e4_0  
[conda] mypy-extensions           0.4.3                    pypi_0    pypi
[conda] numpy                     1.20.3           py38hf144106_0  
[conda] numpy-base                1.20.3           py38h74d4b33_0  

Used conda env:

name: grid_sample
channels:
  - default
  - pytorch
dependencies:
  - python=3.8
  - pytorch=1.9.0
  - numpy
  - ipython
  - black

Additional context

I would also be interested in any (possibly temporary) workarounds.
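
For example, one thing I've been considering is simply swapping in the basic-ops sketch from the description above. A quick sanity check against grid_sample on a small case might look like the following (it assumes grid_sample_manual from the sketch is defined in the same file); I'd expect the outputs and the grid gradients to agree up to floating point error, at least away from the borders.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.rand(2, 3, 8, 9, dtype=torch.double)
grid = 2.0 * torch.rand(2, 5, 6, 2, dtype=torch.double) - 1.0

g_ref = grid.clone().requires_grad_(True)
g_alt = grid.clone().requires_grad_(True)

ref = F.grid_sample(
    inp, g_ref, mode="bilinear", padding_mode="border", align_corners=True
)
alt = grid_sample_manual(inp, g_alt)  # sketch from the bug description

ref.mean().backward()
alt.mean().backward()

# Compare outputs and gradients with respect to the grid.
print(torch.allclose(ref, alt, atol=1e-6))
print(torch.allclose(g_ref.grad, g_alt.grad, atol=1e-6))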

cc @VitalyFedyunin @ngimel @heitorschueroff @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @mruberry @jbschlosser @walterddr

    Labels

    module: autograd (Related to torch.autograd, and the autograd engine in general)
    module: cpu (CPU specific problem (e.g., perf, algorithm))
    module: interpolation
    module: nn (Related to torch.nn)
    module: performance (Issues related to performance, either of kernel code or framework glue)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
