
Value of torch.backends.cudnn.benchmark Baked into JIT-Traced Modules ( 150x slowdown on ConvTranspose2d() ) [jit] [libtorch] [cudnn] #18776

Open
HapeMask opened this issue Apr 2, 2019 · 7 comments
Assignees
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@HapeMask
Contributor

HapeMask commented Apr 2, 2019

🐛 Bug

If you trace a module with torch.jit.trace(...) and load the resulting script module in C++ via LibTorch, its behavior in C++ depends on whether the torch.backends.cudnn.benchmark flag was set at trace time. Calls to at::globalContext().setBenchmarkCuDNN(true/false) from the C++ API at runtime appear to have no effect.

To Reproduce

NOTE: I was not able to verify this issue still exists on the latest nightly (20190402) because it appears the latest nightly (at least on Windows) cannot run JIT-traced models. Even the simplest model gives the following error:

INVALID_ARGUMENT:: Cannot find field. (deserialize at ..\torch\csrc\jit\import.cpp:108)
(no backtrace available)
  1. Run the python script below: python test.py 0 or python test.py 1
  2. Compile + run the C++ code below.
  3. Observe:
    a) Average time per call. I see ~0.8 ms in the Python script and either ~0.8 ms or ~120 ms in C++, depending on the flag used in Python. In either case, the C++ code sets benchmarking ON. (GTX 1080)
    b) Kernel run by cuDNN. With either setting of the flag, the Python code runs cudnn::detail::dgrad_engine<...>. With the flag ON, it runs cudnn::detail::dgrad2d_alg1_1<...> once (taking ~120 ms) and then chooses the faster dgrad_engine. If the flag was ON in Python, C++ also chooses dgrad_engine, but if the flag was OFF in Python, C++ always chooses dgrad2d_alg1_1 regardless of the flag setting in C++.

I observed the choice of kernel using nvprof python test.py 0/1.

Python Script (test.py):

import sys
import time

import torch as th

# Set the cuDNN benchmark flag from the command line (0 or 1).
th.backends.cudnn.benchmark = bool(int(sys.argv[1]))

mod = th.nn.ConvTranspose2d(8, 3, 4, 2, 1).cuda()
inp = th.zeros(1, 8, 512, 512).cuda()

# Warm-up calls; with benchmark=True, cuDNN autotunes the algorithm here.
mod(inp); mod(inp); mod(inp)

# Trace the module and save it for loading from C++.
smod = th.jit.trace(mod, (inp,), check_trace=False)
smod.save("smod.ptj")

# Time the eager module.
N = 1000
th.cuda.synchronize()
start = time.time()
for _ in range(N):
    mod(inp)
    th.cuda.synchronize()
end = time.time()
print("Time (ms):", 1000*(end-start)/N)

C++ Code:

#include <chrono>
#include <iostream>

#include <c10/cuda/CUDAGuard.h>
#include <torch/script.h>
#include <torch/torch.h>

#include <cuda_runtime_api.h>

int main() {
  // Enable cuDNN benchmarking via the C++ API; per this report it appears to have no effect on the loaded traced module.
  at::globalContext().setBenchmarkCuDNN(true);
  auto nograd = torch::NoGradGuard();

  try {
    auto mod = torch::jit::load("smod.ptj");
    mod->to(torch::kCUDA);
    torch::jit::Stack input_stack = {torch::zeros({1, 8, 512, 512}, torch::kCUDA)};

    // Warm-up calls before timing.
    mod->forward(input_stack);
    mod->forward(input_stack);
    mod->forward(input_stack);

    const int N = 100;
    cudaDeviceSynchronize();
    const auto start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < N; ++i) {
      mod->forward(input_stack);
      cudaDeviceSynchronize();
    }
    const auto end = std::chrono::high_resolution_clock::now();
    const float elapsed = std::chrono::duration<float, std::milli>(end - start).count() / N;
    std::cout << "Time (ms): " << elapsed << std::endl;
  } catch (const c10::Error& e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
  return 0;
}

Expected Behavior

I would expect that either: 1) the C++ setting of at::globalContext().setBenchmarkCuDNN(true) is respected (so the correct algorithm is chosen), or 2) at minimum a warning is printed that it is being overridden by the value the flag had at trace time.
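
Until that happens, a possible workaround (my assumption, following directly from the behavior described above) is to make sure the benchmark flag already has the value you want at deployment before tracing, so the constant baked into the graph matches the C++ setting:

import torch as th

# Assumption: the traced graph bakes in the trace-time value of the flag,
# so set it to the value you intend to use from C++ *before* tracing.
th.backends.cudnn.benchmark = True

mod = th.nn.ConvTranspose2d(8, 3, 4, 2, 1).cuda()
inp = th.zeros(1, 8, 512, 512).cuda()
smod = th.jit.trace(mod, (inp,), check_trace=False)
smod.save("smod.ptj")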

Additional Info

I printed the JIT graphs generated with benchmarking ON/OFF and got the following with the flag OFF:

graph(%input : Float(1, 8, 512, 512),
      %1 : Float(8, 3, 4, 4),
      %2 : Float(3)):
  %3 : int = prim::Constant[value=2](), scope: ConvTranspose2d
  %4 : int = prim::Constant[value=2](), scope: ConvTranspose2d
  %5 : int[] = prim::ListConstruct(%3, %4), scope: ConvTranspose2d
  %6 : int = prim::Constant[value=1](), scope: ConvTranspose2d
  %7 : int = prim::Constant[value=1](), scope: ConvTranspose2d
  %8 : int[] = prim::ListConstruct(%6, %7), scope: ConvTranspose2d
  %9 : int = prim::Constant[value=1](), scope: ConvTranspose2d
  %10 : int = prim::Constant[value=1](), scope: ConvTranspose2d
  %11 : int[] = prim::ListConstruct(%9, %10), scope: ConvTranspose2d
  %12 : bool = prim::Constant[value=1](), scope: ConvTranspose2d
  %13 : int = prim::Constant[value=0](), scope: ConvTranspose2d
  %14 : int = prim::Constant[value=0](), scope: ConvTranspose2d
  %15 : int[] = prim::ListConstruct(%13, %14), scope: ConvTranspose2d
  %16 : int = prim::Constant[value=1](), scope: ConvTranspose2d
  %17 : bool = prim::Constant[value=0](), scope: ConvTranspose2d
  %18 : bool = prim::Constant[value=0](), scope: ConvTranspose2d
  %19 : bool = prim::Constant[value=1](), scope: ConvTranspose2d
  %20 : Float(1, 3, 1024, 1024) = aten::_convolution(%input, %1, %2, %5, %8, %11, %12, %15, %16, %17, %18, %19), scope: ConvTranspose2d
  return (%20)

The only change when the flag is ON is that register %17 is 1 instead of 0. I suppose this is where the "hardcoding" of the flag might be happening?
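
A hedged way to confirm the mapping (assuming the aten::_convolution schema of PyTorch builds from that era, whose trailing boolean arguments are benchmark, deterministic, cudnn_enabled) is to re-trace with the flag flipped and compare the dumped graphs; under that schema, %17 above would be the benchmark argument:

# Sketch, reusing mod and inp from test.py above: trace once per flag setting
# and compare the printed graphs. Only the constant feeding the benchmark
# argument of aten::_convolution should differ.
import torch as th

graphs = {}
for flag in (False, True):
    th.backends.cudnn.benchmark = flag
    traced = th.jit.trace(mod, (inp,), check_trace=False)
    graphs[flag] = str(traced.graph)

print(graphs[False])
print(graphs[True])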

Environment

Python code was run on Linux, C++ code was run on Windows

  • PyTorch Version (e.g., 1.0): 1.0.0.dev20190311 on linux, 2336f0b on Windows
  • OS (e.g., Linux): Fedora 29, Windows 10 1809
  • How you installed PyTorch (conda, pip, source): conda (pytorch-nightly)
  • Python version: 3.7
  • CUDA/cuDNN version: CUDA 10, cuDNN 7.4.2
  • GPU models and configuration: Titan RTX (linux), GTX 1080 (windows)

cc @suo

@pytorchbot pytorchbot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Apr 2, 2019
@ailzhang
Contributor

ailzhang commented Apr 4, 2019

Yeah, in the current implementation it is expected behavior for the loaded module to use the flag as a constant set at trace time. We could potentially add a warning for it.
On the other hand, is it possible for us to detect these backend-related flags at runtime and remove them as input arguments to _convolution? cc: @ngimel @apaszke @zdevito

@ezyang
Contributor

ezyang commented Apr 4, 2019

If you adjust it to not trace _convolution, and instead trace the function that wraps around it, you won't hardcode the flag.

@ezyang
Contributor

ezyang commented Apr 4, 2019

I think the way the code is currently structured, there are a lot of call sites of _convolution, so this may be annoying to do. One possibility is to add another intermediate function between those call sites and _convolution which doesn't have the backend flag.

@ailzhang
Contributor

ailzhang commented Apr 4, 2019

@ezyang In tracing we deliberately chose not to trace the function that wraps around _convolution here.

DONT_RECORD_TRACE = {

I'm not very familiar with the reason behind it though.

@soumith
Member

soumith commented Apr 4, 2019

we need the backend flag to compute backward correctly.

@apaszke
Contributor

apaszke commented Apr 7, 2019

@soumith we don't. I wrote symbolic AD for batch norm which allows for multiple backends. I think the idea was that someone might have code like:

torch.backends.cudnn.enabled = False
some_op() # cuDNN bug
torch.backends.cudnn.enabled = True

And we wanted to preserve those semantics. I think we should save those flags when tracing.
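
As an aside, a hedged sketch of the same pattern using the torch.backends.cudnn.flags context manager (assuming it is available in the PyTorch version in use), which makes the scoped-override semantics explicit:

import torch

def some_op():
    # placeholder for the op that hits the cuDNN bug in the example above
    pass

# Assumption: torch.backends.cudnn.flags exists in this PyTorch version and
# restores the previous flag values when the block exits.
with torch.backends.cudnn.flags(enabled=False, benchmark=False, deterministic=False):
    some_op()  # runs with cuDNN disabled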

@jamesr66a jamesr66a added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Oct 7, 2019
bzinodev added a commit that referenced this issue Oct 15, 2019
    Currently, tracing convolution operations records the internal implementation
    function _convolution. As a result, 'benchmark' and other flags are hard-wired
    into the trace, and changes to these context flags at later stages are ignored.

    The proposed fix is to remove 'convolution' from DONT_RECORD_TRACE. Note that
    all convolution operations end up calling the convolution function, so this
    fix applies to all convolution operations.

[ghstack-poisoned]
bzinodev added a commit that referenced this issue Oct 15, 2019

ghstack-source-id: 8368ea9b953e01f016cc1405687778706363136e
Pull Request resolved: #28027
@zasdfgbnm
Collaborator

zasdfgbnm commented May 11, 2020

Can we distinguish the case where the user explicitly sets the flag during the run of the model from the case where the flag is set outside the model?

Specifically, I mean we should trace/script torch.backends.cudnn.enabled = False as an operator such as aten::setBenchmarkCuDNN(false) or prim::setBenchmarkCuDNN(false), and remove the benchmark argument from aten::convolution.

In such a design,

torch.backends.cudnn.enabled = False

@script
def model(x):
    return F.convolution(x, .....)

will respect whatever value the flag has in the context where the model is executed, while

@script
def model(x):
    torch.backends.cudnn.enabled = False
    return F.convolution(x, .....)

will always run without cuDNN regardless of the context.
